/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2015 by Contributors
 * \file org_apache_mxnet_native_c_api.cc
 * \brief JNI function implementations
 */
#include "org_apache_mxnet_native_c_api.h"  // generated by javah
#include <nnvm/c_api.h>
#include <mxnet/c_api.h>
#include <dmlc/logging.h>
#include <mxnet/ndarray.h>
#include <../src/common/cuda_utils.h>
#include <mutex>
#include <iostream>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "jni_helper_func.h"

// Global JavaVM pointer. It is filled in by nativeLibInit (via JNIEnv::GetJavaVM) and
// used by the KVStore callbacks in this file to attach their calling thread to the JVM.
JavaVM *_jvm;

/*!
 * \brief Records the JavaVM associated with `env` in the global `_jvm`.
 * \return the status code returned by JNIEnv::GetJavaVM
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_nativeLibInit
  (JNIEnv *env, jobject obj) {
  const jint status = env->GetJavaVM(&_jvm);
  return status;
}

/*!
 * \brief Appends every name reported by MXListAllOpNames to a Scala ListBuffer.
 * \param nameList scala.collection.mutable.ListBuffer that receives one java.lang.String
 *                 per returned name
 * \return status code of MXListAllOpNames; nothing is appended when it is non-zero
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListAllOpNames
  (JNIEnv *env, jobject obj, jobject nameList) {
  mx_uint outSize = 0;
  const char **outArray = NULL;
  int ret = MXListAllOpNames(&outSize, &outArray);
  if (ret) {
    // The outputs are not trustworthy after a failed call, so do not iterate over them.
    return ret;
  }

  jclass listCls = env->FindClass("scala/collection/mutable/ListBuffer");
  jmethodID listAppend = env->GetMethodID(listCls,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
  for (size_t i = 0; i < outSize; ++i) {
    jstring jname = env->NewStringUTF(outArray[i]);
    jobject appended = env->CallObjectMethod(nameList, listAppend, jname);
    // Both the new string and the reference returned by $plus$eq are local references;
    // release them each iteration so the local reference table does not grow with outSize.
    env->DeleteLocalRef(appended);
    env->DeleteLocalRef(jname);
  }
  env->DeleteLocalRef(listCls);

  return ret;
}

/*!
 * \brief Calls NNGetOpHandle for the given name and stores the handle in a Base$RefLong.
 * \param jopname name passed to NNGetOpHandle
 * \param jhandle org.apache.mxnet.Base$RefLong whose `value` field receives the handle
 * \return status code of NNGetOpHandle
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_nnGetOpHandle
  (JNIEnv *env, jobject obj, jstring jopname, jobject jhandle) {
  const char *nameChars = env->GetStringUTFChars(jopname, 0);
  OpHandle op;
  const int status = NNGetOpHandle(nameChars, &op);
  env->ReleaseStringUTFChars(jopname, nameChars);

  // Write the raw handle bits into jhandle.value.
  jclass refLongCls = env->FindClass("org/apache/mxnet/Base$RefLong");
  jfieldID valueFid = env->GetFieldID(refLongCls, "value", "J");
  const jlong handleBits = reinterpret_cast<jlong>(op);
  env->SetLongField(jhandle, valueFid, handleBits);

  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayCreateNone.
 * \param ndArrayHandle reference object that receives the new handle via SetLongField
 * \return status code of MXNDArrayCreateNone
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateNone
  (JNIEnv *env, jobject obj, jobject ndArrayHandle) {
  NDArrayHandle created;
  const int status = MXNDArrayCreateNone(&created);
  const jlong handleBits = reinterpret_cast<jlong>(created);
  SetLongField(env, ndArrayHandle, handleBits);
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayCreateEx.
 *
 * The elements of `shape` are handed to the C API reinterpreted as mx_uint, and the
 * resulting handle is written to `ndArrayHandle` through the SetLongField helper.
 * \return status code of MXNDArrayCreateEx
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx
  (JNIEnv *env, jobject obj, jintArray shape, jint ndim, jint devType,
    jint devId, jint delayAlloc, jint dtype, jobject ndArrayHandle) {
  NDArrayHandle created;
  jint *dims = env->GetIntArrayElements(shape, NULL);
  mx_uint *dimsAsUint = reinterpret_cast<mx_uint *>(dims);
  const mx_uint rank = static_cast<mx_uint>(ndim);

  const int status = MXNDArrayCreateEx(dimsAsUint, rank, devType, devId,
                                       delayAlloc, dtype, &created);

  env->ReleaseIntArrayElements(shape, dims, 0);
  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(created));
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayCreateSparseEx.
 *
 * The four Java int arrays are pinned/copied, reinterpreted to the pointer types the
 * C API expects, and released again after the call. The resulting handle is written to
 * `ndArrayHandle` through the SetLongField helper.
 * \return status code of MXNDArrayCreateSparseEx
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx
  (JNIEnv *env, jobject obj, jint storageType, jintArray shape, jint ndim, jint devType,
    jint devId, jint delayAlloc, jint dtype, jint numAux, jintArray auxTypes,
    jintArray auxNdims, jintArray auxShapes, jobject ndArrayHandle) {
  jint *dims = env->GetIntArrayElements(shape, NULL);
  jint *auxTypeElems = env->GetIntArrayElements(auxTypes, NULL);
  jint *auxNdimElems = env->GetIntArrayElements(auxNdims, NULL);
  jint *auxShapeElems = env->GetIntArrayElements(auxShapes, NULL);

  // Views of the Java buffers with the element types used by the C API.
  const mx_uint *cShape = reinterpret_cast<const mx_uint *>(dims);
  int *cAuxTypes = reinterpret_cast<int *>(auxTypeElems);
  mx_uint *cAuxNdims = reinterpret_cast<mx_uint *>(auxNdimElems);
  const mx_uint *cAuxShapes = reinterpret_cast<const mx_uint *>(auxShapeElems);

  NDArrayHandle created;
  const int status = MXNDArrayCreateSparseEx(storageType, cShape, static_cast<mx_uint>(ndim),
                                             devType, devId, delayAlloc, dtype,
                                             static_cast<mx_uint>(numAux),
                                             cAuxTypes, cAuxNdims, cAuxShapes, &created);

  env->ReleaseIntArrayElements(shape, dims, 0);
  env->ReleaseIntArrayElements(auxTypes, auxTypeElems, 0);
  env->ReleaseIntArrayElements(auxNdims, auxNdimElems, 0);
  env->ReleaseIntArrayElements(auxShapes, auxShapeElems, 0);

  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(created));
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayWaitAll.
 * \return status code of MXNDArrayWaitAll
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv *env, jobject obj) {
  const int status = MXNDArrayWaitAll();
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayWaitToRead.
 * \param arrayPtr NDArrayHandle carried as a Java long
 * \return status code of MXNDArrayWaitToRead
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitToRead
  (JNIEnv *env, jobject obj, jlong arrayPtr) {
  NDArrayHandle array = reinterpret_cast<NDArrayHandle>(arrayPtr);
  const int status = MXNDArrayWaitToRead(array);
  return status;
}

/*!
 * \brief Appends every FunctionHandle reported by MXListFunctions to a Scala ListBuffer.
 * \param functions scala.collection.mutable.ListBuffer that receives one java.lang.Long
 *                  (the handle value) per function
 * \return status code of MXListFunctions; nothing is appended when it is non-zero
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListFunctions
  (JNIEnv *env, jobject obj, jobject functions) {
  // Get function list
  FunctionHandle *outArray = NULL;
  mx_uint outSize = 0;
  int ret = MXListFunctions(&outSize, &outArray);
  if (ret) {
    // The outputs are not trustworthy after a failed call, so do not iterate over them.
    return ret;
  }

  jclass longCls = env->FindClass("java/lang/Long");
  jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");

  // scala.collection.mutable.ListBuffer append method
  jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
  jmethodID listAppend = env->GetMethodID(listClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

  for (size_t i = 0; i < outSize; ++i) {
    // The Long constructor takes a `J` argument, so convert the pointer to jlong
    // explicitly instead of pushing a raw pointer through the varargs call.
    jobject boxed = env->NewObject(longCls, longConst, reinterpret_cast<jlong>(outArray[i]));
    jobject appended = env->CallObjectMethod(functions, listAppend, boxed);
    // Release per-iteration local references so the table does not grow with outSize.
    env->DeleteLocalRef(appended);
    env->DeleteLocalRef(boxed);
  }
  env->DeleteLocalRef(listClass);
  env->DeleteLocalRef(longCls);
  return ret;
}

/*!
 * \brief Calls MXFuncDescribe and copies its four outputs into Base$RefInt objects.
 * \param funcPtr     FunctionHandle carried as a Java long
 * \param nUsedVars   Base$RefInt receiving the first count output
 * \param nScalars    Base$RefInt receiving the second count output
 * \param nMutateVars Base$RefInt receiving the third count output
 * \param typeMask    Base$RefInt receiving the int output
 * \return status code of MXFuncDescribe
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncDescribe
  (JNIEnv *env, jobject obj, jlong funcPtr, jobject nUsedVars,
    jobject nScalars, jobject nMutateVars, jobject typeMask) {
  mx_uint usedVarCount;
  mx_uint scalarCount;
  mx_uint mutateVarCount;
  int mask;
  FunctionHandle fn = reinterpret_cast<FunctionHandle>(funcPtr);
  const int status = MXFuncDescribe(fn, &usedVarCount, &scalarCount, &mutateVarCount, &mask);

  // All four outputs share the same Base$RefInt.value field id.
  jclass refIntCls = env->FindClass("org/apache/mxnet/Base$RefInt");
  jfieldID valueFid = env->GetFieldID(refIntCls, "value", "I");
  env->SetIntField(nUsedVars, valueFid, static_cast<jint>(usedVarCount));
  env->SetIntField(nScalars, valueFid, static_cast<jint>(scalarCount));
  env->SetIntField(nMutateVars, valueFid, static_cast<jint>(mutateVarCount));
  env->SetIntField(typeMask, valueFid, static_cast<jint>(mask));

  return status;
}

/*!
 * \brief Calls MXFuncGetInfo and copies the results into Scala-side holders.
 * \param funcPtr  FunctionHandle carried as a Java long
 * \param name     Base$RefString receiving the function name
 * \param desc     Base$RefString receiving the description
 * \param numArgs  Base$RefInt receiving the argument count
 * \param argNames ListBuffer receiving one String per argument name
 * \param argTypes ListBuffer receiving one String per argument type
 * \param argDescs ListBuffer receiving one String per argument description
 * \return status code of MXFuncGetInfo; no holder is modified when it is non-zero
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncGetInfo
  (JNIEnv *env, jobject obj, jlong funcPtr, jobject name, jobject desc,
    jobject numArgs, jobject argNames, jobject argTypes, jobject argDescs) {
  const char *cName = NULL;
  const char *cDesc = NULL;
  mx_uint cNumArgs = 0;
  const char **cArgNames = NULL;
  const char **cArgTypes = NULL;
  const char **cArgDescs = NULL;
  int ret = MXFuncGetInfo(reinterpret_cast<FunctionHandle>(funcPtr),
                          &cName, &cDesc, &cNumArgs,
                          &cArgNames, &cArgTypes, &cArgDescs);
  if (ret) {
    // The output pointers are not trustworthy after a failed call; building Java
    // strings from them would read invalid memory.
    return ret;
  }

  jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
  jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");

  jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
  jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");

  // scala.collection.mutable.ListBuffer append method
  jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
  jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
      "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

  jstring jname = env->NewStringUTF(cName);
  env->SetObjectField(name, valueStr, jname);
  env->DeleteLocalRef(jname);
  jstring jdesc = env->NewStringUTF(cDesc);
  env->SetObjectField(desc, valueStr, jdesc);
  env->DeleteLocalRef(jdesc);
  env->SetIntField(numArgs, valueInt, static_cast<jint>(cNumArgs));

  // Parallel native arrays and the list each one is appended to.
  const char **sources[3] = {cArgNames, cArgTypes, cArgDescs};
  jobject targets[3] = {argNames, argTypes, argDescs};
  for (size_t i = 0; i < cNumArgs; ++i) {
    for (int k = 0; k < 3; ++k) {
      jstring item = env->NewStringUTF(sources[k][i]);
      jobject appended = env->CallObjectMethod(targets[k], listAppend, item);
      // Release per-iteration local references so the table does not grow with cNumArgs.
      env->DeleteLocalRef(appended);
      env->DeleteLocalRef(item);
    }
  }

  env->DeleteLocalRef(listClass);
  env->DeleteLocalRef(refStringClass);
  env->DeleteLocalRef(refIntClass);
  return ret;
}

/*!
 * \brief JNI bridge to MXImperativeInvokeEx.
 * \param funcPtr      AtomicSymbolCreator carried as a Java long
 * \param inputs       input NDArray handles
 * \param outputsGiven optional (may be null) array of caller-supplied output handles
 * \param outputs      scala ArrayBuffer receiving one java.lang.Long per output handle
 * \param numParams    number of entries read from paramKeys / paramVals
 * \param paramKeys    parameter names (String[])
 * \param paramVals    parameter values (String[])
 * \param outStypes    scala ArrayBuffer receiving one java.lang.Integer per output
 * \return status code of MXImperativeInvokeEx; `outputs` and `outStypes` are only
 *         filled when it is zero
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx
  (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray inputs,
    jlongArray outputsGiven, jobject outputs, jint numParams,
    jobjectArray paramKeys, jobjectArray paramVals, jobject outStypes) {

  const char **cParamKeys = NULL;
  const char **cParamVals = NULL;
  if (numParams > 0) {
    cParamKeys = new const char *[numParams];
    cParamVals = new const char *[numParams];
    for (int i = 0; i < numParams; i++) {
      jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
      const char *key = env->GetStringUTFChars(jkey, 0);
      cParamKeys[i] = key;
      env->DeleteLocalRef(jkey);
      jstring jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
      const char *val = env->GetStringUTFChars(jval, 0);
      cParamVals[i] = val;
      env->DeleteLocalRef(jval);
    }
  }

  int numOutputs = 0;
  jlong *cOutputsGiven = NULL;
  NDArrayHandle *cOutputs = NULL;
  // Initialised so a failed call can never leave an indeterminate pointer behind.
  const int *cOutStypes = NULL;
  if (outputsGiven) {
    cOutputsGiven = env->GetLongArrayElements(outputsGiven, NULL);
    cOutputs = reinterpret_cast<NDArrayHandle *>(cOutputsGiven);
    numOutputs = static_cast<int>(env->GetArrayLength(outputsGiven));
  }
  jlong *cInputs = env->GetLongArrayElements(inputs, NULL);
  jsize numInputs = env->GetArrayLength(inputs);
  int ret = MXImperativeInvokeEx(reinterpret_cast<AtomicSymbolCreator>(funcPtr),
                               static_cast<int>(numInputs),
                               reinterpret_cast<NDArrayHandle *>(cInputs),
                               &numOutputs,
                               &cOutputs,
                               static_cast<int>(numParams),
                               cParamKeys,
                               cParamVals,
                               &cOutStypes);
  env->ReleaseLongArrayElements(inputs, cInputs, 0);

  // release allocated memory
  if (numParams > 0) {
    for (int i = 0; i < numParams; i++) {
      // The jstring is fetched again because its local reference was dropped above.
      jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
      env->ReleaseStringUTFChars(jkey, cParamKeys[i]);
      env->DeleteLocalRef(jkey);
      jstring jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
      env->ReleaseStringUTFChars(jval, cParamVals[i]);
      env->DeleteLocalRef(jval);
    }
    delete[] cParamKeys;
    delete[] cParamVals;
  }

  // Only publish results of a successful call. With caller-supplied outputs cOutputs is
  // non-null even on failure, and cOutStypes would then not point at valid data.
  if (ret == 0 && cOutputs && cOutStypes) {
    jclass longCls = env->FindClass("java/lang/Long");
    jclass intCls = env->FindClass("java/lang/Integer");
    jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
    jmethodID intConst = env->GetMethodID(intCls, "<init>", "(I)V");
    // scala.collection.mutable.ArrayBuffer append method
    jclass listClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
    jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
        "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
    for (int i = 0; i < numOutputs; ++i) {
      jobject boxedHandle = env->NewObject(longCls, longConst,
                                           reinterpret_cast<jlong>(cOutputs[i]));
      jobject appendedHandle = env->CallObjectMethod(outputs, listAppend, boxedHandle);
      jobject boxedStype = env->NewObject(intCls, intConst, static_cast<jint>(cOutStypes[i]));
      jobject appendedStype = env->CallObjectMethod(outStypes, listAppend, boxedStype);
      // Release per-iteration local references so the table does not grow with numOutputs.
      env->DeleteLocalRef(appendedStype);
      env->DeleteLocalRef(boxedStype);
      env->DeleteLocalRef(appendedHandle);
      env->DeleteLocalRef(boxedHandle);
    }
    env->DeleteLocalRef(listClass);
    env->DeleteLocalRef(intCls);
    env->DeleteLocalRef(longCls);
  }

  if (cOutputsGiven) {
    env->ReleaseLongArrayElements(outputsGiven, cOutputsGiven, 0);
  }

  return ret;
}

/*!
 * \brief JNI bridge to MXFuncInvoke.
 * \param funcPtr    FunctionHandle carried as a Java long
 * \param useVars    NDArray handles passed as the "use" variables
 * \param scalarArgs float scalar arguments
 * \param mutateVars NDArray handles passed as the "mutate" variables
 * \return status code of MXFuncInvoke
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncInvoke
  (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray useVars,
    jfloatArray scalarArgs, jlongArray mutateVars) {
  jlong *useElems = env->GetLongArrayElements(useVars, NULL);
  jfloat *scalarElems = env->GetFloatArrayElements(scalarArgs, NULL);
  jlong *mutateElems = env->GetLongArrayElements(mutateVars, NULL);

  FunctionHandle fn = reinterpret_cast<FunctionHandle>(funcPtr);
  NDArrayHandle *useHandles = reinterpret_cast<NDArrayHandle *>(useElems);
  mx_float *scalars = reinterpret_cast<mx_float *>(scalarElems);
  NDArrayHandle *mutateHandles = reinterpret_cast<NDArrayHandle *>(mutateElems);
  const int status = MXFuncInvoke(fn, useHandles, scalars, mutateHandles);

  env->ReleaseLongArrayElements(useVars, useElems, 0);
  env->ReleaseFloatArrayElements(scalarArgs, scalarElems, 0);
  env->ReleaseLongArrayElements(mutateVars, mutateElems, 0);
  return status;
}

/*!
 * \brief JNI bridge to MXFuncInvokeEx.
 *
 * paramKeys / paramVals are arrays of byte arrays; each element's bytes are handed to
 * the C API as a `char *`.
 * \param funcPtr    FunctionHandle carried as a Java long
 * \param useVars    NDArray handles passed as the "use" variables
 * \param scalarArgs float scalar arguments
 * \param mutateVars NDArray handles passed as the "mutate" variables
 * \param numParams  number of entries read from paramKeys / paramVals
 * \return status code of MXFuncInvokeEx
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncInvokeEx
  (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray useVars,
    jfloatArray scalarArgs, jlongArray mutateVars,
    jint numParams, jobjectArray paramKeys, jobjectArray paramVals) {
  jlong *useElems = env->GetLongArrayElements(useVars, NULL);
  jfloat *scalarElems = env->GetFloatArrayElements(scalarArgs, NULL);
  jlong *mutateElems = env->GetLongArrayElements(mutateVars, NULL);

  // Native views of every key/value byte array; left null when there are no parameters.
  jbyte **keyBytes = NULL;
  jbyte **valBytes = NULL;
  if (numParams > 0) {
    keyBytes = new jbyte *[numParams];
    valBytes = new jbyte *[numParams];
    for (int idx = 0; idx < numParams; ++idx) {
      jbyteArray keyArr = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramKeys, idx));
      keyBytes[idx] = env->GetByteArrayElements(keyArr, NULL);
      env->DeleteLocalRef(keyArr);

      jbyteArray valArr = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramVals, idx));
      valBytes[idx] = env->GetByteArrayElements(valArr, NULL);
      env->DeleteLocalRef(valArr);
    }
  }

  const int status = MXFuncInvokeEx(reinterpret_cast<FunctionHandle>(funcPtr),
                                    reinterpret_cast<NDArrayHandle *>(useElems),
                                    reinterpret_cast<mx_float *>(scalarElems),
                                    reinterpret_cast<NDArrayHandle *>(mutateElems),
                                    static_cast<int>(numParams),
                                    reinterpret_cast<char **>(keyBytes),
                                    reinterpret_cast<char **>(valBytes));

  env->ReleaseLongArrayElements(useVars, useElems, 0);
  env->ReleaseFloatArrayElements(scalarArgs, scalarElems, 0);
  env->ReleaseLongArrayElements(mutateVars, mutateElems, 0);

  // Hand every byte buffer back to its array; the array objects are fetched again
  // because their local references were dropped above. The loop body never runs when
  // numParams <= 0, and delete[] on a null pointer is a no-op.
  for (int idx = 0; idx < numParams; ++idx) {
    jbyteArray keyArr = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramKeys, idx));
    env->ReleaseByteArrayElements(keyArr, keyBytes[idx], 0);
    env->DeleteLocalRef(keyArr);

    jbyteArray valArr = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramVals, idx));
    env->ReleaseByteArrayElements(valArr, valBytes[idx], 0);
    env->DeleteLocalRef(valArr);
  }
  delete[] keyBytes;
  delete[] valBytes;

  return status;
}

/*!
 * \brief Calls MXNDArraySaveRawBytes and appends every returned byte to a Scala ArrayBuffer.
 * \param ndArrayPtr NDArrayHandle carried as a Java long
 * \param dataBuf    scala.collection.mutable.ArrayBuffer receiving one java.lang.Byte per byte
 * \return status code of MXNDArraySaveRawBytes; nothing is appended when it is non-zero
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySaveRawBytes
  (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject dataBuf) {
  size_t length = 0;
  const char *pdata = NULL;
  int ret = MXNDArraySaveRawBytes(reinterpret_cast<NDArrayHandle>(ndArrayPtr), &length, &pdata);
  if (ret) {
    // The outputs are not trustworthy after a failed call, so do not iterate over them.
    return ret;
  }

  // fill dataBuf
  jclass byteClass = env->FindClass("java/lang/Byte");
  jmethodID newByte = env->GetMethodID(byteClass, "<init>", "(B)V");
  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID arrayAppend = env->GetMethodID(arrayClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
  for (size_t i = 0; i < length; ++i) {
    jobject data = env->NewObject(byteClass, newByte, static_cast<jbyte>(pdata[i]));
    // $plus$eq returns an object, which arrives here as a fresh local reference; with one
    // per byte these would pile up for the whole call unless released here.
    jobject appended = env->CallObjectMethod(dataBuf, arrayAppend, data);
    env->DeleteLocalRef(appended);
    env->DeleteLocalRef(data);
  }
  env->DeleteLocalRef(arrayClass);
  env->DeleteLocalRef(byteClass);

  return ret;
}

/*!
 * \brief JNI bridge to MXNDArrayLoadFromRawBytes.
 * \param bytes     byte array handed to the C API together with its length
 * \param handleRef reference object that receives the resulting handle via SetLongField
 * \return status code of MXNDArrayLoadFromRawBytes
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes
  (JNIEnv *env, jobject obj, jbyteArray bytes, jobject handleRef) {
  int byteCount = env->GetArrayLength(bytes);
  jbyte *rawBytes = env->GetByteArrayElements(bytes, NULL);

  NDArrayHandle loaded;
  const void *buffer = reinterpret_cast<const void *>(rawBytes);
  const int status = MXNDArrayLoadFromRawBytes(buffer, static_cast<size_t>(byteCount), &loaded);

  env->ReleaseByteArrayElements(bytes, rawBytes, 0);
  SetLongField(env, handleRef, reinterpret_cast<jlong>(loaded));
  return status;
}

/*!
 * \brief Calls MXNDArrayGetShapeEx and copies the result to the Scala side.
 * \param ndArrayPtr NDArrayHandle carried as a Java long
 * \param ndimRef    Base$RefInt receiving the number of dimensions
 * \param dataBuf    scala.collection.mutable.ArrayBuffer receiving one java.lang.Integer
 *                   per dimension
 * \return status code of MXNDArrayGetShapeEx; neither output is touched when it is non-zero
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
  (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject ndimRef, jobject dataBuf) {
  int ndim = 0;
  const int *pdata = NULL;
  int ret = MXNDArrayGetShapeEx(reinterpret_cast<NDArrayHandle>(ndArrayPtr), &ndim, &pdata);
  if (ret) {
    // The outputs are not trustworthy after a failed call, so do not read them.
    return ret;
  }

  // fill dataBuf
  jclass integerClass = env->FindClass("java/lang/Integer");
  jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");

  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID arrayAppend = env->GetMethodID(arrayClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
  for (int i = 0; i < ndim; ++i) {
    jobject data = env->NewObject(integerClass, newInteger, pdata[i]);
    // The object returned by $plus$eq is a local reference too; release it with the box.
    jobject appended = env->CallObjectMethod(dataBuf, arrayAppend, data);
    env->DeleteLocalRef(appended);
    env->DeleteLocalRef(data);
  }

  // set ndimRef
  jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
  jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
  env->SetIntField(ndimRef, valueInt, ndim);

  return ret;
}

/*!
 * \brief JNI bridge to MXNDArraySyncCopyFromNDArray.
 * \param dstPtr  destination NDArrayHandle carried as a Java long
 * \param srcPtr  source NDArrayHandle carried as a Java long
 * \param locator forwarded unchanged as the third argument
 * \return status code of MXNDArraySyncCopyFromNDArray
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray
  (JNIEnv *env, jobject obj, jlong dstPtr, jlong srcPtr, jint locator) {
  NDArrayHandle dst = reinterpret_cast<NDArrayHandle>(dstPtr);
  NDArrayHandle src = reinterpret_cast<NDArrayHandle>(srcPtr);
  return MXNDArraySyncCopyFromNDArray(dst, src, locator);
}

/*!
 * \brief JNI bridge to MXNDArraySyncCopyToCPU, writing into a Java byte array.
 * \param ndArrayPtr NDArrayHandle carried as a Java long
 * \param data       destination byte array
 * \param size       forwarded unchanged as the size argument of the C API
 * \return status code of MXNDArraySyncCopyToCPU
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyToCPU
  (JNIEnv *env, jobject obj, jlong ndArrayPtr, jbyteArray data, jint size) {
  jbyte *dest = env->GetByteArrayElements(data, NULL);
  NDArrayHandle array = reinterpret_cast<NDArrayHandle>(ndArrayPtr);
  const int status = MXNDArraySyncCopyToCPU(array, reinterpret_cast<void *>(dest), size);
  // Mode 0 copies the native buffer back into the Java array (if it was a copy) and frees it.
  env->ReleaseByteArrayElements(data, dest, 0);
  return status;
}

/*!
 * \brief JNI bridge to MXNDArraySlice.
 * \param ndArrayPtr   NDArrayHandle carried as a Java long
 * \param start        forwarded unchanged as the slice start
 * \param end          forwarded unchanged as the slice end
 * \param slicedHandle reference object that receives the resulting handle via SetLongField
 * \return status code of MXNDArraySlice
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySlice
  (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint start, jint end, jobject slicedHandle) {
  NDArrayHandle source = reinterpret_cast<NDArrayHandle>(ndArrayPtr);
  NDArrayHandle sliced;
  const int status = MXNDArraySlice(source, start, end, &sliced);
  SetLongField(env, slicedHandle, reinterpret_cast<jlong>(sliced));
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayAt.
 * \param ndArrayPtr NDArrayHandle carried as a Java long
 * \param idx        forwarded unchanged as the index argument
 * \param jout       reference object that receives the resulting handle via SetLongField
 * \return status code of MXNDArrayAt
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
  (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint idx, jobject jout) {
  NDArrayHandle source = reinterpret_cast<NDArrayHandle>(ndArrayPtr);
  NDArrayHandle element;
  const int status = MXNDArrayAt(source, idx, &element);
  SetLongField(env, jout, reinterpret_cast<jlong>(element));
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayReshape64.
 * \param ndArrayPtr     NDArrayHandle carried as a Java long
 * \param ndim           forwarded unchanged as the dimension count
 * \param dims           long array whose elements are passed reinterpreted as dim_t
 * \param reverse        forwarded unchanged
 * \param reshapedHandle reference object that receives the resulting handle via SetLongField
 * \return status code of MXNDArrayReshape64
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
  (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim,
   jlongArray dims, jboolean reverse, jobject reshapedHandle) {
  jlong *dimElems = env->GetLongArrayElements(dims, NULL);
  NDArrayHandle source = reinterpret_cast<NDArrayHandle>(ndArrayPtr);
  NDArrayHandle reshaped;
  const int status = MXNDArrayReshape64(source, ndim, reinterpret_cast<dim_t *>(dimElems),
                                        reverse, &reshaped);
  SetLongField(env, reshapedHandle, reinterpret_cast<jlong>(reshaped));
  env->ReleaseLongArrayElements(dims, dimElems, 0);
  return status;
}

/*!
 * \brief JNI bridge to MXNDArraySyncCopyFromCPU for a Java float array.
 * \param arrayPtr  destination NDArrayHandle carried as a Java long
 * \param sourceArr float array supplying the data
 * \param arrSize   forwarded unchanged as the size argument of the C API
 * \return status code of MXNDArraySyncCopyFromCPU
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
  (JNIEnv *env, jobject obj, jlong arrayPtr, jfloatArray sourceArr, jint arrSize) {
  jfloat *floats = env->GetFloatArrayElements(sourceArr, NULL);
  NDArrayHandle dest = reinterpret_cast<NDArrayHandle>(arrayPtr);
  const int status = MXNDArraySyncCopyFromCPU(dest, static_cast<const mx_float *>(floats),
                                              arrSize);
  env->ReleaseFloatArrayElements(sourceArr, floats, 0);
  return status;
}

/*!
 * \brief JNI bridge to MXNDArraySyncCopyFromCPU for a Java double array.
 * \param arrayPtr  destination NDArrayHandle carried as a Java long
 * \param sourceArr double array supplying the data
 * \param arrSize   forwarded unchanged as the size argument of the C API
 * \return status code of MXNDArraySyncCopyFromCPU
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
  (JNIEnv *env, jobject obj, jlong arrayPtr, jdoubleArray sourceArr, jint arrSize) {
  jdouble *doubles = env->GetDoubleArrayElements(sourceArr, NULL);
  NDArrayHandle dest = reinterpret_cast<NDArrayHandle>(arrayPtr);
  const int status = MXNDArraySyncCopyFromCPU(dest, static_cast<const double *>(doubles),
                                              arrSize);
  env->ReleaseDoubleArrayElements(sourceArr, doubles, 0);
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayGetDataNDArray.
 * \param arrayPtr      NDArrayHandle carried as a Java long
 * \param ndArrayHandle reference object that receives the resulting handle via SetLongField
 * \return status code of MXNDArrayGetDataNDArray
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray
  (JNIEnv *env, jobject obj, jlong arrayPtr, jobject ndArrayHandle) {
  NDArrayHandle source = reinterpret_cast<NDArrayHandle>(arrayPtr);
  NDArrayHandle dataArray;
  const int status = MXNDArrayGetDataNDArray(source, &dataArray);
  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(dataArray));
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayGetAuxNDArray.
 * \param arrayPtr      NDArrayHandle carried as a Java long
 * \param location      index argument, converted to mx_uint
 * \param ndArrayHandle reference object that receives the resulting handle via SetLongField
 * \return status code of MXNDArrayGetAuxNDArray
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray
  (JNIEnv *env, jobject obj, jlong arrayPtr, jint location, jobject ndArrayHandle) {
  NDArrayHandle source = reinterpret_cast<NDArrayHandle>(arrayPtr);
  const mx_uint auxIndex = static_cast<mx_uint>(location);
  NDArrayHandle auxArray;
  const int status = MXNDArrayGetAuxNDArray(source, auxIndex, &auxArray);
  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(auxArray));
  return status;
}

/*!
 * \brief Calls MXNDArrayGetContext and stores both outputs in Base$RefInt objects.
 * \param arrayPtr  NDArrayHandle carried as a Java long
 * \param devTypeId Base$RefInt receiving the device type output
 * \param devId     Base$RefInt receiving the device id output
 * \return status code of MXNDArrayGetContext
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext
  (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) {
  int deviceType;
  int deviceIndex;
  NDArrayHandle array = reinterpret_cast<NDArrayHandle>(arrayPtr);
  const int status = MXNDArrayGetContext(array, &deviceType, &deviceIndex);

  // Both outputs go into the `value` field of a Base$RefInt.
  jclass refIntCls = env->FindClass("org/apache/mxnet/Base$RefInt");
  jfieldID valueFid = env->GetFieldID(refIntCls, "value", "I");
  env->SetIntField(devTypeId, valueFid, deviceType);
  env->SetIntField(devId, valueFid, deviceIndex);
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayFree.
 * \param ndArrayHandle NDArrayHandle carried as a Java long
 * \return status code of MXNDArrayFree
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayFree
  (JNIEnv * env, jobject obj, jlong ndArrayHandle) {
  NDArrayHandle array = reinterpret_cast<NDArrayHandle>(ndArrayHandle);
  const int status = MXNDArrayFree(array);
  return status;
}

/*!
 * \brief Calls MXNDArrayLoad and copies the returned handles and names to the Scala side.
 * \param jfname       file name passed to MXNDArrayLoad
 * \param joutSize     Base$RefInt receiving the number of handles
 * \param jhandles     scala ArrayBuffer receiving one java.lang.Long per handle
 * \param joutNameSize Base$RefInt receiving the number of names
 * \param jnames       scala ArrayBuffer receiving one String per name
 * \return status code of MXNDArrayLoad; no output is touched when it is non-zero
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoad
  (JNIEnv * env, jobject obj, jstring jfname, jobject joutSize,
    jobject jhandles, jobject joutNameSize, jobject jnames) {
  mx_uint outSize;
  NDArrayHandle *outArr;
  mx_uint outNameSize;
  const char **outNames;

  const char *fname = env->GetStringUTFChars(jfname, 0);
  int ret = MXNDArrayLoad(fname, &outSize, &outArr, &outNameSize, &outNames);
  env->ReleaseStringUTFChars(jfname, fname);

  if (ret) {
    return ret;
  }

  // fill sizes
  jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
  jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
  env->SetIntField(joutSize, valueInt, outSize);
  env->SetIntField(joutNameSize, valueInt, outNameSize);

  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID arrayAppend = env->GetMethodID(arrayClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");

  // fill handles
  jclass longCls = env->FindClass("java/lang/Long");
  jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
  for (size_t i = 0; i < outSize; ++i) {
    // The Long constructor takes a `J` argument, so convert the pointer to jlong
    // explicitly instead of pushing a raw pointer through the varargs call.
    jobject handle = env->NewObject(longCls, longConst, reinterpret_cast<jlong>(outArr[i]));
    // The object returned by $plus$eq is a local reference too; release it with the box.
    jobject appended = env->CallObjectMethod(jhandles, arrayAppend, handle);
    env->DeleteLocalRef(appended);
    env->DeleteLocalRef(handle);
  }

  // fill names
  for (size_t i = 0; i < outNameSize; ++i) {
    jstring jname = env->NewStringUTF(outNames[i]);
    jobject appended = env->CallObjectMethod(jnames, arrayAppend, jname);
    env->DeleteLocalRef(appended);
    env->DeleteLocalRef(jname);
  }

  return ret;
}

/*!
 * \brief JNI bridge to MXNDArraySave.
 * \param jfname   file name passed to MXNDArraySave
 * \param jhandles NDArray handles; its length is used as the argument count
 * \param jkeys    optional (may be null) String[]; when present, one entry per handle is read
 * \return status code of MXNDArraySave
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySave
  (JNIEnv * env, jobject obj, jstring jfname, jlongArray jhandles, jobjectArray jkeys) {
  int handleCount = env->GetArrayLength(jhandles);
  const bool hasKeys = (jkeys != NULL);

  // UTF-8 views of the keys; stays null when no keys were supplied.
  const char **keyChars = NULL;
  if (hasKeys) {
    keyChars = new const char *[handleCount];
    for (int idx = 0; idx < handleCount; ++idx) {
      jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, idx));
      keyChars[idx] = env->GetStringUTFChars(keyStr, 0);
      env->DeleteLocalRef(keyStr);
    }
  }

  const char *path = env->GetStringUTFChars(jfname, 0);
  jlong *handleElems = env->GetLongArrayElements(jhandles, NULL);

  const int status = MXNDArraySave(path, static_cast<mx_uint>(handleCount),
                                   reinterpret_cast<NDArrayHandle *>(handleElems), keyChars);

  env->ReleaseLongArrayElements(jhandles, handleElems, 0);
  env->ReleaseStringUTFChars(jfname, path);

  // Give the key characters back; each jstring is fetched again because its local
  // reference was dropped above.
  if (hasKeys) {
    for (int idx = 0; idx < handleCount; ++idx) {
      jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, idx));
      env->ReleaseStringUTFChars(keyStr, keyChars[idx]);
      env->DeleteLocalRef(keyStr);
    }
    delete[] keyChars;
  }

  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayGetDType.
 * \param jhandle NDArrayHandle carried as a Java long
 * \param jdtype  reference object that receives the dtype via the SetIntField helper
 * \return status code of MXNDArrayGetDType
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType
  (JNIEnv * env, jobject obj, jlong jhandle, jobject jdtype) {
  NDArrayHandle array = reinterpret_cast<NDArrayHandle>(jhandle);
  int typeCode;
  const int status = MXNDArrayGetDType(array, &typeCode);
  SetIntField(env, jdtype, typeCode);
  return status;
}

/*!
 * \brief JNI bridge to MXNDArrayGetStorageType.
 * \param jhandle NDArrayHandle carried as a Java long
 * \param jstype  reference object that receives the storage type via the SetIntField helper
 * \return status code of MXNDArrayGetStorageType
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType
  (JNIEnv * env, jobject obj, jlong jhandle, jobject jstype) {
  NDArrayHandle array = reinterpret_cast<NDArrayHandle>(jhandle);
  int storageCode;
  const int status = MXNDArrayGetStorageType(array, &storageCode);
  SetIntField(env, jstype, storageCode);
  return status;
}

/*!
 * \brief JNI bridge to MXInitPSEnv.
 * \param jkeys String[] of keys; its length is used as the number of pairs
 * \param jvals String[] of values, read at the same indices as jkeys
 * \return status code of MXInitPSEnv
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxInitPSEnv
  (JNIEnv *env, jobject obj, jobjectArray jkeys, jobjectArray jvals) {
  int pairCount = env->GetArrayLength(jkeys);
  const char **keyChars = new const char *[pairCount];
  const char **valChars = new const char *[pairCount];

  // Borrow the UTF-8 characters of every key and value. The jstring local references are
  // dropped immediately; the character buffers stay valid until they are released below.
  for (int idx = 0; idx < pairCount; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, idx));
    keyChars[idx] = env->GetStringUTFChars(keyStr, 0);
    env->DeleteLocalRef(keyStr);

    jstring valStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, idx));
    valChars[idx] = env->GetStringUTFChars(valStr, 0);
    env->DeleteLocalRef(valStr);
  }

  const int status = MXInitPSEnv(static_cast<mx_uint>(pairCount), keyChars, valChars);

  // Release the borrowed characters; each jstring is fetched again for the release call.
  for (int idx = 0; idx < pairCount; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, idx));
    env->ReleaseStringUTFChars(keyStr, keyChars[idx]);
    env->DeleteLocalRef(keyStr);

    jstring valStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, idx));
    env->ReleaseStringUTFChars(valStr, valChars[idx]);
    env->DeleteLocalRef(valStr);
  }
  delete[] keyChars;
  delete[] valChars;

  return status;
}

extern "C" void KVStoreServerControllerFunc
  (int head, const char *body, void *handle) {
  jobject controllerObjGlb = static_cast<jobject>(handle);

  JNIEnv *env;
  _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);

  // find java controller method
  jclass ctrlClass = env->GetObjectClass(controllerObjGlb);
  jmethodID ctrlFunc = env->GetMethodID(ctrlClass, "invoke", "(ILjava/lang/String;)V");

  jstring jbody = env->NewStringUTF(body);
  env->CallVoidMethod(controllerObjGlb, ctrlFunc, head, jbody);
  env->DeleteLocalRef(jbody);

  env->DeleteLocalRef(ctrlClass);
  // FIXME(Yizhi): This function can be called multiple times,
  // can we find a way to safely destroy this object ?
  // env->DeleteGlobalRef(controllerObjGlb);
}

/*!
 * \brief JNI bridge to MXKVStoreRunServer.
 *
 * A global reference to `controllerObj` is created and passed as the callback state of
 * KVStoreServerControllerFunc. The reference is not deleted here.
 * \param kvStorePtr KVStoreHandle carried as a Java long
 * \return status code of MXKVStoreRunServer
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreRunServer
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject controllerObj) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  jobject pinnedController = env->NewGlobalRef(controllerObj);
  void *callbackState = reinterpret_cast<void *>(pinnedController);
  return MXKVStoreRunServer(store, KVStoreServerControllerFunc, callbackState);
}

/*!
 * \brief Native callback that forwards a KVStore update to a Java updater object.
 *
 * Attaches the calling thread to the JVM stored in `_jvm`, wraps `recv` and `local` in
 * org.apache.mxnet.NDArray objects and calls `update(int, NDArray, NDArray)` on the
 * object behind `handle`.
 * \param key    forwarded as the int argument of `update`
 * \param recv   handle wrapped as the first NDArray argument
 * \param local  handle wrapped as the second NDArray argument
 * \param handle the jobject registered as callback state (a global reference created in
 *               mxKVStoreSetUpdater)
 */
extern "C" void KVStoreUpdaterCallbackFunc
  (int key, NDArrayHandle recv, NDArrayHandle local, void *handle) {
  jobject updaterFuncObjGlb = static_cast<jobject>(handle);

  JNIEnv *env;
  _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);

  // find java updater method
  jclass updtClass = env->GetObjectClass(updaterFuncObjGlb);
  jmethodID updtFunc = env->GetMethodID(updtClass,
    "update", "(ILorg/apache/mxnet/NDArray;Lorg/apache/mxnet/NDArray;)V");

  // find java NDArray constructor
  jclass ndObjClass = env->FindClass("org/apache/mxnet/NDArray");
  jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "<init>", "(JZZ)V");

  // The constructor signature (JZZ)V takes a long and TWO booleans, so both must be
  // supplied; omitting the second one makes the JVM read an indeterminate vararg.
  // NOTE(review): JNI_TRUE is used for the second boolean on the assumption that it
  // matches the Scala-side default of that constructor parameter - confirm against
  // org.apache.mxnet.NDArray.
  jobject ndRecv = env->NewObject(ndObjClass, ndObjConstructor,
                                  reinterpret_cast<jlong>(recv), JNI_TRUE, JNI_TRUE);
  jobject ndLocal = env->NewObject(ndObjClass, ndObjConstructor,
                                   reinterpret_cast<jlong>(local), JNI_TRUE, JNI_TRUE);

  env->CallVoidMethod(updaterFuncObjGlb, updtFunc, key, ndRecv, ndLocal);

  env->DeleteLocalRef(ndLocal);
  env->DeleteLocalRef(ndRecv);
  env->DeleteLocalRef(ndObjClass);
  env->DeleteLocalRef(updtClass);
  // FIXME(Yizhi): This function can be called multiple times,
  // can we find a way to safely destroy this object ?
  // env->DeleteGlobalRef(updaterFuncObjGlb);
}

/*!
 * \brief JNI bridge to MXKVStoreSetUpdater.
 *
 * A global reference to `updaterFuncObj` is created and passed as the callback state of
 * KVStoreUpdaterCallbackFunc. The reference is not deleted here.
 * \param kvStorePtr KVStoreHandle carried as a Java long
 * \return status code of MXKVStoreSetUpdater
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSetUpdater
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject updaterFuncObj) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  jobject pinnedUpdater = env->NewGlobalRef(updaterFuncObj);
  void *callbackState = reinterpret_cast<void *>(pinnedUpdater);
  return MXKVStoreSetUpdater(store, KVStoreUpdaterCallbackFunc, callbackState);
}

/*!
 * \brief JNI bridge to MXKVStoreIsWorkerNode.
 * \param isWorkerRef reference object that receives the int output via the SetIntField helper
 * \return status code of MXKVStoreIsWorkerNode
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreIsWorkerNode
  (JNIEnv *env, jobject obj, jobject isWorkerRef) {
  int workerFlag;
  const int status = MXKVStoreIsWorkerNode(&workerFlag);
  SetIntField(env, isWorkerRef, workerFlag);
  return status;
}

/*!
 * \brief Calls MXKVStoreCreate and stores the resulting handle in a Base$RefLong.
 * \param name          string passed to MXKVStoreCreate as the store type
 * \param kvStoreHandle org.apache.mxnet.Base$RefLong whose `value` field receives the handle
 * \return status code of MXKVStoreCreate
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreCreate
  (JNIEnv *env, jobject obj, jstring name, jobject kvStoreHandle) {
  // Resolve Base$RefLong.value up front.
  jclass refLongCls = env->FindClass("org/apache/mxnet/Base$RefLong");
  jfieldID valueFid = env->GetFieldID(refLongCls, "value", "J");

  const char *typeChars = env->GetStringUTFChars(name, 0);
  KVStoreHandle store;
  const int status = MXKVStoreCreate(typeChars, &store);
  env->ReleaseStringUTFChars(name, typeChars);

  const jlong handleBits = reinterpret_cast<jlong>(store);
  env->SetLongField(kvStoreHandle, valueFid, handleBits);
  return status;
}

/*!
 * \brief JNI bridge to MXKVStoreInit with integer keys.
 * \param kvStorePtr KVStoreHandle carried as a Java long
 * \param len        number of key/value pairs, converted to mx_uint
 * \param keys       integer keys
 * \param values     NDArray handles carried as Java longs
 * \return status code of MXKVStoreInit
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreInit
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys, jlongArray values) {
  jint *keyElems = env->GetIntArrayElements(keys, NULL);
  jlong *valueElems = env->GetLongArrayElements(values, NULL);

  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  const int *cKeys = static_cast<const int *>(keyElems);
  NDArrayHandle *cValues = reinterpret_cast<NDArrayHandle *>(valueElems);
  const int status = MXKVStoreInit(store, static_cast<mx_uint>(len), cKeys, cValues);

  env->ReleaseIntArrayElements(keys, keyElems, 0);
  env->ReleaseLongArrayElements(values, valueElems, 0);
  return status;
}

/*!
 * \brief JNI bridge to MXKVStoreInitEx with string keys.
 * \param kvStorePtr KVStoreHandle carried as a Java long
 * \param len        number of key/value pairs; this many entries are read from `keys`
 * \param keys       String[] of keys
 * \param values     NDArray handles carried as Java longs
 * \return status code of MXKVStoreInitEx
 */
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreInitEx
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys, jlongArray values) {
  // Borrow the UTF-8 characters of every key; the jstring local references are dropped
  // right away and the characters are released after the call.
  const char **keyChars = new const char *[len];
  for (int idx = 0; idx < len; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, idx));
    keyChars[idx] = env->GetStringUTFChars(keyStr, 0);
    env->DeleteLocalRef(keyStr);
  }

  jlong *valueElems = env->GetLongArrayElements(values, NULL);
  const int status = MXKVStoreInitEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
                                     static_cast<mx_uint>(len),
                                     keyChars,
                                     reinterpret_cast<NDArrayHandle *>(valueElems));
  env->ReleaseLongArrayElements(values, valueElems, 0);

  for (int idx = 0; idx < len; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, idx));
    env->ReleaseStringUTFChars(keyStr, keyChars[idx]);
    env->DeleteLocalRef(keyStr);
  }
  delete[] keyChars;
  return status;
}

// JNI bridge for MXKVStorePush with integer keys.
// Obtains native pointers to the Java key and value arrays, forwards them with
// `len` and `priority` to the C API, and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePush
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
    jlongArray values, jint priority) {
  jint *keyArray = env->GetIntArrayElements(keys, NULL);
  jlong *valueArray = env->GetLongArrayElements(values, NULL);
  int ret = MXKVStorePush(reinterpret_cast<KVStoreHandle>(kvStorePtr),
                          static_cast<mx_uint>(len),
                          static_cast<const int *>(keyArray),
                          reinterpret_cast<NDArrayHandle *>(valueArray),
                          priority);
  // Each Get*ArrayElements call must be matched by a release, otherwise the
  // native buffer (or pin) obtained for the array is leaked. Both the key
  // and the value array are released here.
  env->ReleaseIntArrayElements(keys, keyArray, 0);
  env->ReleaseLongArrayElements(values, valueArray, 0);
  return ret;
}

// JNI bridge for MXKVStorePushEx with string keys.
// Builds a temporary table of UTF-8 key strings, calls the C API with the
// value handles and `priority`, then releases the strings and the table.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePushEx
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
    jlongArray values, jint priority) {
  const char **cKeys = new const char *[len];
  for (int idx = 0; idx < len; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, idx));
    cKeys[idx] = env->GetStringUTFChars(keyStr, nullptr);
    // The local reference is dropped now and looked up again when releasing.
    env->DeleteLocalRef(keyStr);
  }

  jlong *nativeValues = env->GetLongArrayElements(values, nullptr);
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  NDArrayHandle *valuePtr = reinterpret_cast<NDArrayHandle *>(nativeValues);
  const int status =
      MXKVStorePushEx(store, static_cast<mx_uint>(len), cKeys, valuePtr, priority);
  env->ReleaseLongArrayElements(values, nativeValues, 0);

  // Hand the UTF-8 buffers back to the JVM, then free the pointer table.
  for (int idx = 0; idx < len; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, idx));
    env->ReleaseStringUTFChars(keyStr, cKeys[idx]);
    env->DeleteLocalRef(keyStr);
  }
  delete[] cKeys;
  return status;
}

// JNI bridge for MXKVStorePull with integer keys.
// Obtains native pointers to the key array and the output-handle array, calls
// the C API with `len` and `priority`, releases both arrays with mode 0 and
// returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePull
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
    jlongArray outs, jint priority) {
  jint *nativeKeys = env->GetIntArrayElements(keys, nullptr);
  jlong *nativeOuts = env->GetLongArrayElements(outs, nullptr);

  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  const int *keyPtr = static_cast<const int *>(nativeKeys);
  NDArrayHandle *outPtr = reinterpret_cast<NDArrayHandle *>(nativeOuts);
  const int status =
      MXKVStorePull(store, static_cast<mx_uint>(len), keyPtr, outPtr, priority);

  env->ReleaseIntArrayElements(keys, nativeKeys, 0);
  env->ReleaseLongArrayElements(outs, nativeOuts, 0);
  return status;
}

// JNI bridge for MXKVStorePullEx with string keys.
// Builds a temporary table of UTF-8 key strings, calls the C API with the
// output handles and `priority`, then releases the strings and the table.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePullEx
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
    jlongArray outs, jint priority) {
  const char **cKeys = new const char *[len];
  for (int idx = 0; idx < len; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, idx));
    cKeys[idx] = env->GetStringUTFChars(keyStr, nullptr);
    // The local reference is dropped now and looked up again when releasing.
    env->DeleteLocalRef(keyStr);
  }

  jlong *nativeOuts = env->GetLongArrayElements(outs, nullptr);
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  NDArrayHandle *outPtr = reinterpret_cast<NDArrayHandle *>(nativeOuts);
  const int status =
      MXKVStorePullEx(store, static_cast<mx_uint>(len), cKeys, outPtr, priority);
  env->ReleaseLongArrayElements(outs, nativeOuts, 0);

  // Hand the UTF-8 buffers back to the JVM, then free the pointer table.
  for (int idx = 0; idx < len; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, idx));
    env->ReleaseStringUTFChars(keyStr, cKeys[idx]);
    env->DeleteLocalRef(keyStr);
  }
  delete[] cKeys;
  return status;
}

// JNI bridge for MXKVStoreGetType.
// Stores the type string reported by the C API into the `value` field of the
// Base$RefString object `kvType` and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetType
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject kvType) {
  const char *typeName;
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  const int status = MXKVStoreGetType(store, &typeName);

  jclass stringRefCls = env->FindClass("org/apache/mxnet/Base$RefString");
  jfieldID valueField = env->GetFieldID(stringRefCls, "value", "Ljava/lang/String;");
  jstring jtypeName = env->NewStringUTF(typeName);
  env->SetObjectField(kvType, valueField, jtypeName);
  return status;
}

// JNI bridge for MXKVStoreSendCommmandToServers.
// Converts `body` to a UTF-8 C string for the duration of the call and passes
// it, with `head`, to the C API. Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSendCommmandToServers
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint head, jstring body) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  const char *bodyUtf = env->GetStringUTFChars(body, nullptr);
  const int status = MXKVStoreSendCommmandToServers(store, head, bodyUtf);
  env->ReleaseStringUTFChars(body, bodyUtf);
  return status;
}

// JNI bridge for MXKVStoreBarrier: reinterprets `kvStorePtr` as a KVStore
// handle and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreBarrier
  (JNIEnv *env, jobject obj, jlong kvStorePtr) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  const int status = MXKVStoreBarrier(store);
  return status;
}

// JNI bridge for MXKVStoreGetGroupSize.
// Writes the size reported by the C API into `sizeRef` via SetIntField and
// returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetGroupSize
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject sizeRef) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  int groupSize;
  const int status = MXKVStoreGetGroupSize(store, &groupSize);
  SetIntField(env, sizeRef, groupSize);
  return status;
}

// JNI bridge for MXKVStoreGetRank.
// Writes the rank reported by the C API into `rankRef` via SetIntField and
// returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetRank
  (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject rankRef) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  int workerRank;
  const int status = MXKVStoreGetRank(store, &workerRank);
  SetIntField(env, rankRef, workerRank);
  return status;
}

// JNI bridge for MXKVStoreGetNumDeadNode.
// Passes `nodeId` to the C API, writes the resulting count into `numberRef`
// via SetIntField and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetNumDeadNode
  (JNIEnv * env, jobject obj, jlong kvStorePtr, jint nodeId, jobject numberRef) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  const int node = static_cast<int>(nodeId);
  int deadCount;
  const int status = MXKVStoreGetNumDeadNode(store, node, &deadCount);
  SetIntField(env, numberRef, deadCount);
  return status;
}

// JNI bridge for MXKVStoreSetBarrierBeforeExit: forwards `doBarrier` as an
// int to the C API and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSetBarrierBeforeExit
  (JNIEnv * env, jobject obj, jlong kvStorePtr, jint doBarrier) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(kvStorePtr);
  const int barrierFlag = static_cast<int>(doBarrier);
  return MXKVStoreSetBarrierBeforeExit(store, barrierFlag);
}

// JNI bridge for MXKVStoreFree: reinterprets `ptr` as a KVStore handle,
// passes it to the C API and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreFree
  (JNIEnv * env, jobject obj, jlong ptr) {
  KVStoreHandle store = reinterpret_cast<KVStoreHandle>(ptr);
  const int status = MXKVStoreFree(store);
  return status;
}

// JNI bridge for MXExecutorOutputs.
// Appends every output handle returned by the C API, boxed as java.lang.Long,
// to the scala ArrayBuffer `outputs` and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorOutputs
  (JNIEnv *env, jobject obj, jlong executorPtr, jobject outputs) {
  mx_uint outSize;
  NDArrayHandle *out;
  int ret = MXExecutorOutputs(reinterpret_cast<ExecutorHandle>(executorPtr), &outSize, &out);

  jclass longCls = env->FindClass("java/lang/Long");
  jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");

  // fill java outputs
  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID arrayAppend = env->GetMethodID(arrayClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
  for (size_t i = 0; i < outSize; ++i) {
    // The constructor signature is (J)V, so the handle is converted to jlong
    // explicitly instead of passing a raw pointer through the varargs call.
    jobject boxedHandle = env->NewObject(longCls, longConst,
                                         reinterpret_cast<jlong>(out[i]));
    env->CallObjectMethod(outputs, arrayAppend, boxedHandle);
    // Drop the boxed Long's local reference each iteration so the number of
    // live local references does not grow with the number of outputs.
    env->DeleteLocalRef(boxedHandle);
  }

  return ret;
}

// JNI bridge for MXExecutorFree: reinterprets `ptr` as an executor handle,
// passes it to the C API and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorFree
  (JNIEnv * env, jobject obj, jlong ptr) {
  ExecutorHandle executor = reinterpret_cast<ExecutorHandle>(ptr);
  const int status = MXExecutorFree(executor);
  return status;
}

// JNI bridge for MXExecutorForward: forwards `isTrain` as an int to the C API
// for the executor identified by `ptr` and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorForward
  (JNIEnv * env, jobject obj, jlong ptr, jint isTrain) {
  ExecutorHandle executor = reinterpret_cast<ExecutorHandle>(ptr);
  const int trainFlag = static_cast<int>(isTrain);
  return MXExecutorForward(executor, trainFlag);
}

// JNI bridge for MXExecutorBackward.
// Passes the contents of `grads` (reinterpreted as NDArray handles) and its
// length to the C API, releases the array and returns the status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBackward
  (JNIEnv * env, jobject obj, jlong executorPtr, jlongArray grads) {
  const int numGrads = env->GetArrayLength(grads);
  jlong *nativeGrads = env->GetLongArrayElements(grads, nullptr);

  ExecutorHandle executor = reinterpret_cast<ExecutorHandle>(executorPtr);
  NDArrayHandle *gradHandles = reinterpret_cast<NDArrayHandle *>(nativeGrads);
  const int status =
      MXExecutorBackward(executor, static_cast<mx_uint>(numGrads), gradHandles);

  env->ReleaseLongArrayElements(grads, nativeGrads, 0);
  return status;
}

// JNI bridge for MXExecutorReshapeEx.
//
// Inputs converted for the C API:
//   - jmapKeys with the parallel jmapDevTypes / jmapDevIds int arrays,
//   - jprovidedArgShapeNames with the jprovidedArgShapeData /
//     jprovidedArgShapeIdx int arrays,
//   - jsharedExec, reinterpreted as an executor handle.
// Outputs written back to Java:
//   - in-arg and arg-grad handles, boxed as java.lang.Long and appended to the
//     scala ArrayBuffers jrefInArgs / jrefArgGrads,
//   - aux-state handles appended to jrefAuxStates,
//   - the resulting executor handle stored into jrefOut via SetLongField.
// Returns the status code of MXExecutorReshapeEx.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape
  (JNIEnv * env, jobject obj,
    jint partialReshaping, jint allowUpSizing, jint devType, jint devId,
    jobjectArray jmapKeys, jintArray jmapDevTypes, jintArray jmapDevIds,
    jobjectArray jprovidedArgShapeNames, jintArray jprovidedArgShapeData,
    jintArray jprovidedArgShapeIdx, jobject jrefInArgs, jobject jrefArgGrads,
    jobject jrefAuxStates, jlong jsharedExec, jobject jrefOut) {
  CHECK(jmapKeys != NULL);
  CHECK(jprovidedArgShapeNames != NULL);

  int numMapKeys = env->GetArrayLength(jmapKeys);
  jint *mapDevTypes = env->GetIntArrayElements(jmapDevTypes, NULL);
  jint *mapDevIds = env->GetIntArrayElements(jmapDevIds, NULL);
  const char **mapKeys = NULL;
  if (numMapKeys > 0) {
    mapKeys = new const char*[numMapKeys];
    for (int i = 0; i < numMapKeys; ++i) {
      jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
      mapKeys[i] = env->GetStringUTFChars(jkey, 0);
      env->DeleteLocalRef(jkey);
    }
  }

  int numProvidedArgShapes = env->GetArrayLength(jprovidedArgShapeNames);
  jint *providedArgShapeData = env->GetIntArrayElements(jprovidedArgShapeData, NULL);
  jint *providedArgShapeIdx = env->GetIntArrayElements(jprovidedArgShapeIdx, NULL);
  const char **providedArgShapeNames = NULL;
  if (numProvidedArgShapes > 0) {
    providedArgShapeNames = new const char*[numProvidedArgShapes];
    for (int i = 0; i < numProvidedArgShapes; ++i) {
      jstring jkey = reinterpret_cast<jstring>(
          env->GetObjectArrayElement(jprovidedArgShapeNames, i));
      providedArgShapeNames[i] = env->GetStringUTFChars(jkey, 0);
      env->DeleteLocalRef(jkey);
    }
  }

  mx_uint numInArgs = 0;
  NDArrayHandle *inArgs;
  NDArrayHandle *argGrads;

  mx_uint numAuxStates = 0;
  NDArrayHandle *auxStates;

  ExecutorHandle out;

  int ret = MXExecutorReshapeEx(partialReshaping,
                                allowUpSizing,
                                devType,
                                devId,
                                static_cast<mx_uint>(numMapKeys),
                                mapKeys,
                                static_cast<const int*>(mapDevTypes),
                                static_cast<const int*>(mapDevIds),
                                static_cast<const mx_uint>(numProvidedArgShapes),
                                providedArgShapeNames,
                                static_cast<const int*>(providedArgShapeData),
                                reinterpret_cast<const mx_uint*>(providedArgShapeIdx),
                                &numInArgs,
                                &inArgs,
                                &argGrads,
                                &numAuxStates,
                                &auxStates,
                                reinterpret_cast<ExecutorHandle>(jsharedExec),
                                &out);

  jclass longCls = env->FindClass("java/lang/Long");
  jmethodID newLong = env->GetMethodID(longCls, "<init>", "(J)V");

  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID arrayAppend = env->GetMethodID(arrayClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");

  // The Long constructor signature is (J)V, so handles are converted to jlong
  // explicitly rather than passed as raw pointers through the varargs call.
  for (size_t i = 0; i < numInArgs; ++i) {
    jobject inArg = env->NewObject(longCls, newLong, reinterpret_cast<jlong>(inArgs[i]));
    env->CallObjectMethod(jrefInArgs, arrayAppend, inArg);
    env->DeleteLocalRef(inArg);

    jobject argGrad = env->NewObject(longCls, newLong, reinterpret_cast<jlong>(argGrads[i]));
    env->CallObjectMethod(jrefArgGrads, arrayAppend, argGrad);
    env->DeleteLocalRef(argGrad);
  }

  for (size_t i = 0; i < numAuxStates; ++i) {
    jobject auxState = env->NewObject(longCls, newLong, reinterpret_cast<jlong>(auxStates[i]));
    env->CallObjectMethod(jrefAuxStates, arrayAppend, auxState);
    env->DeleteLocalRef(auxState);
  }

  SetLongField(env, jrefOut, reinterpret_cast<jlong>(out));

  // release allocated memory
  // The four int arrays obtained with GetIntArrayElements above must each be
  // released, otherwise their native buffers (or pins) leak on every call.
  env->ReleaseIntArrayElements(jmapDevTypes, mapDevTypes, 0);
  env->ReleaseIntArrayElements(jmapDevIds, mapDevIds, 0);
  env->ReleaseIntArrayElements(jprovidedArgShapeData, providedArgShapeData, 0);
  env->ReleaseIntArrayElements(jprovidedArgShapeIdx, providedArgShapeIdx, 0);

  for (int i = 0; i < numMapKeys; i++) {
    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
    env->ReleaseStringUTFChars(jkey, mapKeys[i]);
    env->DeleteLocalRef(jkey);
  }
  // delete[] on a null pointer is a no-op, so no guard is needed.
  delete[] mapKeys;

  for (int i = 0; i < numProvidedArgShapes; i++) {
    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jprovidedArgShapeNames, i));
    env->ReleaseStringUTFChars(jkey, providedArgShapeNames[i]);
    env->DeleteLocalRef(jkey);
  }
  delete[] providedArgShapeNames;

  return ret;
}

// JNI bridge for MXExecutorPrint.
// Stores the string returned by the C API into `debugStr` via SetStringField
// and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorPrint
  (JNIEnv * env, jobject obj, jlong ptr, jobject debugStr) {
  ExecutorHandle executor = reinterpret_cast<ExecutorHandle>(ptr);
  const char *printed;
  const int status = MXExecutorPrint(executor, &printed);
  SetStringField(env, debugStr, printed);
  return status;
}

extern "C" void ExecutorMonitorCallbackFunc
  (const char *name, NDArrayHandle arr, void *handle) {
  jobject callbackFuncObjGlb = static_cast<jobject>(handle);

  JNIEnv *env;
  _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);

  // find java callback method
  jclass callbackClass = env->GetObjectClass(callbackFuncObjGlb);
  jmethodID callbackFunc = env->GetMethodID(callbackClass, "invoke", "(Ljava/lang/String;J)V");

  // invoke java callback method
  jstring jname = env->NewStringUTF(name);
  env->CallVoidMethod(callbackFuncObjGlb, callbackFunc, jname, reinterpret_cast<jlong>(arr));
  env->DeleteLocalRef(jname);

  env->DeleteLocalRef(callbackClass);
  // FIXME(Yizhi): This function can be called multiple times,
  // can we find a way to safely destroy this global ref ?
  // env->DeleteGlobalRef(callbackFuncObjGlb);
}
// JNI bridge for MXExecutorSetMonitorCallback.
// Creates a global reference to `callbackFuncObj` and registers
// ExecutorMonitorCallbackFunc with that reference as its opaque handle.
// The global reference is not deleted in this function.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallback
  (JNIEnv *env, jobject obj, jlong executorPtr, jobject callbackFuncObj) {
  jobject globalCallback = env->NewGlobalRef(callbackFuncObj);
  ExecutorHandle executor = reinterpret_cast<ExecutorHandle>(executorPtr);
  void *callbackHandle = reinterpret_cast<void *>(globalCallback);
  return MXExecutorSetMonitorCallback(executor, ExecutorMonitorCallbackFunc, callbackHandle);
}

// JNI bridge for MXGetLastError: returns the C API's message as a Java string.
JNIEXPORT jstring JNICALL Java_org_apache_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) {
  const char *lastError = MXGetLastError();
  jstring message = env->NewStringUTF(lastError);
  return message;
}

// IO funcs
// JNI bridge for MXListDataIters.
// Appends every creator handle returned by the C API, boxed as
// java.lang.Long, to the scala ListBuffer `creators` and returns the C API's
// status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListDataIters
  (JNIEnv * env, jobject obj, jobject creators) {
  jclass longCls = env->FindClass("java/lang/Long");
  jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");

  // scala.collection.mutable.ListBuffer append method
  jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
  jmethodID listAppend = env->GetMethodID(listClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

  // Get function list
  DataIterCreator *outArray;
  mx_uint outSize;
  int ret = MXListDataIters(&outSize, &outArray);
  for (size_t i = 0; i < outSize; ++i) {
    // Box the handle as a jlong to match the (J)V constructor signature.
    jobject boxedCreator = env->NewObject(longCls, longConst,
                                          reinterpret_cast<jlong>(outArray[i]));
    env->CallObjectMethod(creators, listAppend, boxedCreator);
    // Drop the local reference each iteration so the number of live local
    // references does not grow with the number of creators.
    env->DeleteLocalRef(boxedCreator);
  }
  return ret;
}

// JNI bridge for MXDataIterCreateIter.
// Converts the parallel Java string arrays `jkeys` / `jvals` into tables of
// UTF-8 C strings (their count is the length of `jkeys`), calls the C API,
// stores the resulting handle into `dataIterHandleRef` via SetLongField and
// finally releases all strings and tables. Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterCreateIter
  (JNIEnv * env, jobject obj, jlong creator, jobjectArray jkeys,
    jobjectArray jvals, jobject dataIterHandleRef) {
  const int numParams = env->GetArrayLength(jkeys);
  const char **cKeys = new const char*[numParams];
  const char **cVals = new const char*[numParams];

  // Borrow UTF-8 views of every key/value; the local references are dropped
  // immediately and looked up again when the strings are released below.
  for (int idx = 0; idx < numParams; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, idx));
    cKeys[idx] = env->GetStringUTFChars(keyStr, nullptr);
    env->DeleteLocalRef(keyStr);

    jstring valStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, idx));
    cVals[idx] = env->GetStringUTFChars(valStr, nullptr);
    env->DeleteLocalRef(valStr);
  }

  // Create the iterator and publish its handle to the Java reference object.
  DataIterHandle iterHandle;
  const int status = MXDataIterCreateIter(reinterpret_cast<DataIterCreator>(creator),
                                          static_cast<mx_uint>(numParams),
                                          static_cast<const char**>(cKeys),
                                          static_cast<const char**>(cVals),
                                          &iterHandle);
  SetLongField(env, dataIterHandleRef, reinterpret_cast<jlong>(iterHandle));

  // Give the UTF-8 buffers back to the JVM and free the pointer tables.
  for (int idx = 0; idx < numParams; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, idx));
    env->ReleaseStringUTFChars(keyStr, cKeys[idx]);
    env->DeleteLocalRef(keyStr);

    jstring valStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, idx));
    env->ReleaseStringUTFChars(valStr, cVals[idx]);
    env->DeleteLocalRef(valStr);
  }
  delete[] cKeys;
  delete[] cVals;

  return status;
}

// JNI bridge for MXDataIterGetIterInfo.
// Writes the name and description strings into the `value` field of the
// Base$RefString objects `jname` / `jdesc`, and appends the per-argument
// names, type infos and descriptions to the scala ListBuffers `jargNames`,
// `jargTypeInfos` and `jargDescs`. Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetIterInfo
  (JNIEnv * env, jobject obj, jlong creator, jobject jname,
    jobject jdesc, jobject jargNames, jobject jargTypeInfos, jobject jargDescs) {
  const char* name;
  const char* description;
  mx_uint numArgs;
  const char** argNames;
  const char** argTypeInfos;
  const char** argDescs;
  int ret = MXDataIterGetIterInfo(reinterpret_cast<DataIterCreator>(creator),
                                   &name,
                                   &description,
                                   &numArgs,
                                   &argNames,
                                   &argTypeInfos,
                                   &argDescs);

  jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
  jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
  // set params
  env->SetObjectField(jname, valueStr, env->NewStringUTF(name));
  env->SetObjectField(jdesc, valueStr, env->NewStringUTF(description));
  jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
  jmethodID listAppend = env->GetMethodID(listClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
  for (size_t i = 0; i < numArgs; i++) {
    // Each temporary jstring is deleted right after it has been appended so
    // the number of live local references stays constant across iterations.
    jstring argName = env->NewStringUTF(argNames[i]);
    env->CallObjectMethod(jargNames, listAppend, argName);
    env->DeleteLocalRef(argName);

    jstring argTypeInfo = env->NewStringUTF(argTypeInfos[i]);
    env->CallObjectMethod(jargTypeInfos, listAppend, argTypeInfo);
    env->DeleteLocalRef(argTypeInfo);

    jstring argDesc = env->NewStringUTF(argDescs[i]);
    env->CallObjectMethod(jargDescs, listAppend, argDesc);
    env->DeleteLocalRef(argDesc);
  }
  return ret;
}

// JNI bridge for MXDataIterFree: reinterprets `handle` as a data-iterator
// handle, passes it to the C API and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterFree
  (JNIEnv *env, jobject obj, jlong handle) {
  DataIterHandle iter = reinterpret_cast<DataIterHandle>(handle);
  return MXDataIterFree(iter);
}

// JNI bridge for MXDataIterBeforeFirst: reinterprets `handle` as a
// data-iterator handle, passes it to the C API and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterBeforeFirst
  (JNIEnv *env, jobject obj, jlong handle) {
  DataIterHandle iter = reinterpret_cast<DataIterHandle>(handle);
  return MXDataIterBeforeFirst(iter);
}

// JNI bridge for MXDataIterNext.
// Writes the int produced by the C API into `out` via SetIntField and returns
// the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterNext
  (JNIEnv *env, jobject obj, jlong handle, jobject out) {
  DataIterHandle iter = reinterpret_cast<DataIterHandle>(handle);
  int nextOut;
  const int status = MXDataIterNext(iter, &nextOut);
  SetIntField(env, out, nextOut);
  return status;
}

// JNI bridge for MXDataIterGetLabel.
// Stores the NDArray handle returned by the C API into `ndArrayHandleRef`
// via SetLongField and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetLabel
  (JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) {
  DataIterHandle iter = reinterpret_cast<DataIterHandle>(handle);
  NDArrayHandle label;
  const int status = MXDataIterGetLabel(iter, &label);
  SetLongField(env, ndArrayHandleRef, reinterpret_cast<jlong>(label));
  return status;
}

// JNI bridge for MXDataIterGetData.
// Stores the NDArray handle returned by the C API into `ndArrayHandleRef`
// via SetLongField and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetData
  (JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) {
  DataIterHandle iter = reinterpret_cast<DataIterHandle>(handle);
  NDArrayHandle data;
  const int status = MXDataIterGetData(iter, &data);
  SetLongField(env, ndArrayHandleRef, reinterpret_cast<jlong>(data));
  return status;
}

// JNI bridge for MXDataIterGetIndex.
// Stores the number of indices into `outSize` via SetLongField and appends
// every index, boxed as java.lang.Long, to the scala ListBuffer `outIndex`.
// Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetIndex
  (JNIEnv *env, jobject obj, jlong handle, jobject outIndex, jobject outSize) {
  uint64_t* coutIndex;
  uint64_t coutSize;
  int ret = MXDataIterGetIndex(reinterpret_cast<DataIterHandle>(handle), &coutIndex, &coutSize);
  // set field
  SetLongField(env, outSize, static_cast<jlong>(coutSize));
  // scala.collection.mutable.ListBuffer append method
  jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
  jmethodID listAppend = env->GetMethodID(listClass,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

  // long class
  jclass longCls = env->FindClass("java/lang/Long");
  jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");

  for (size_t i = 0; i < coutSize; i++) {
    // Convert explicitly to jlong to match the (J)V constructor signature.
    jobject boxedIndex = env->NewObject(longCls, longConst,
                                        static_cast<jlong>(coutIndex[i]));
    env->CallObjectMethod(outIndex, listAppend, boxedIndex);
    // The index count is only known at run time; delete the local reference
    // each iteration so live local references do not grow with it.
    env->DeleteLocalRef(boxedIndex);
  }
  return ret;
}

// JNI bridge for MXDataIterGetPadNum.
// Writes the int produced by the C API into `pad` via SetIntField and returns
// the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetPadNum
  (JNIEnv *env, jobject obj, jlong handle, jobject pad) {
  DataIterHandle iter = reinterpret_cast<DataIterHandle>(handle);
  int padCount;
  const int status = MXDataIterGetPadNum(iter, &padCount);
  SetIntField(env, pad, padCount);
  return status;
}

// Symbol functions
// JNI bridge for MXSymbolFree: reinterprets `ptr` as a symbol handle, passes
// it to the C API and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolFree
  (JNIEnv * env, jobject obj, jlong ptr) {
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(ptr);
  return MXSymbolFree(symbol);
}

// JNI bridge for MXSymbolListAtomicSymbolCreators.
// Appends every creator handle returned by the C API, boxed as
// java.lang.Long, to the scala ListBuffer `symbolList` and returns the C
// API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAtomicSymbolCreators
  (JNIEnv *env, jobject obj, jobject symbolList) {
  mx_uint outSize;
  AtomicSymbolCreator *outArray;
  int ret = MXSymbolListAtomicSymbolCreators(&outSize, &outArray);

  jclass longCls = env->FindClass("java/lang/Long");
  jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");

  jclass listCls = env->FindClass("scala/collection/mutable/ListBuffer");
  jmethodID listAppend = env->GetMethodID(listCls,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

  for (size_t i = 0; i < outSize; ++i) {
    // Convert the handle to jlong explicitly to match the (J)V constructor
    // signature instead of passing a raw pointer through the varargs call.
    jobject boxedCreator = env->NewObject(longCls, longConst,
                                          reinterpret_cast<jlong>(outArray[i]));
    env->CallObjectMethod(symbolList, listAppend, boxedCreator);
    // Delete the local reference each iteration so the number of live local
    // references does not grow with the number of creators.
    env->DeleteLocalRef(boxedCreator);
  }

  return ret;
}

// JNI bridge for MXSymbolGetAtomicSymbolInfo.
// Writes the name, description and key-var-num-args strings into the `value`
// field of the Base$RefString objects `name`, `desc` and `keyVarNumArgs`,
// the argument count into the Base$RefInt `numArgs`, and appends the
// per-argument names, types and descriptions to the scala ListBuffers
// `argNames`, `argTypes` and `argDescs`. Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject name, jobject desc, jobject numArgs,
    jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs) {

  const char *cName;
  const char *cDesc;
  mx_uint cNumArgs;
  const char **cArgNames;
  const char **cArgTypes;
  const char **cArgDescs;
  const char *cKeyVarNumArgs;

  int ret = MXSymbolGetAtomicSymbolInfo(reinterpret_cast<AtomicSymbolCreator>(symbolPtr),
                                        &cName, &cDesc, &cNumArgs,
                                        &cArgNames, &cArgTypes, &cArgDescs,
                                        &cKeyVarNumArgs);

  jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
  jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");

  jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
  jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");

  // scala.collection.mutable.ListBuffer append method
  jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
  jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
      "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

  env->SetObjectField(name, valueStr, env->NewStringUTF(cName));
  env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc));
  env->SetObjectField(keyVarNumArgs, valueStr, env->NewStringUTF(cKeyVarNumArgs));
  env->SetIntField(numArgs, valueInt, static_cast<jint>(cNumArgs));
  for (size_t i = 0; i < cNumArgs; ++i) {
    // Each temporary jstring is deleted right after it has been appended so
    // the number of live local references stays constant across iterations.
    jstring argName = env->NewStringUTF(cArgNames[i]);
    env->CallObjectMethod(argNames, listAppend, argName);
    env->DeleteLocalRef(argName);

    jstring argType = env->NewStringUTF(cArgTypes[i]);
    env->CallObjectMethod(argTypes, listAppend, argType);
    env->DeleteLocalRef(argType);

    jstring argDesc = env->NewStringUTF(cArgDescs[i]);
    env->CallObjectMethod(argDescs, listAppend, argDesc);
    env->DeleteLocalRef(argDesc);
  }

  return ret;
}

// JNI bridge for MXSymbolCreateAtomicSymbol.
// Converts the parallel Java string arrays `paramKeys` / `paramVals` into
// tables of UTF-8 C strings (their count is the length of `paramKeys`), calls
// the C API, stores the resulting handle into `symbolRef` via SetLongField
// and releases all strings and tables. Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateAtomicSymbol
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray paramKeys,
    jobjectArray paramVals, jobject symbolRef) {
  const int numParams = env->GetArrayLength(paramKeys);
  const char **cKeys = new const char*[numParams];
  const char **cVals = new const char*[numParams];

  // Borrow UTF-8 views; local references are dropped and re-fetched later.
  for (int idx = 0; idx < numParams; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, idx));
    cKeys[idx] = env->GetStringUTFChars(keyStr, nullptr);
    env->DeleteLocalRef(keyStr);

    jstring valStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, idx));
    cVals[idx] = env->GetStringUTFChars(valStr, nullptr);
    env->DeleteLocalRef(valStr);
  }

  SymbolHandle created;
  AtomicSymbolCreator creator = reinterpret_cast<AtomicSymbolCreator>(symbolPtr);
  const int status = MXSymbolCreateAtomicSymbol(
      creator, static_cast<mx_uint>(numParams), cKeys, cVals, &created);
  SetLongField(env, symbolRef, reinterpret_cast<jlong>(created));

  // Give the UTF-8 buffers back to the JVM and free the pointer tables.
  for (int idx = 0; idx < numParams; ++idx) {
    jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, idx));
    env->ReleaseStringUTFChars(keyStr, cKeys[idx]);
    env->DeleteLocalRef(keyStr);

    jstring valStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, idx));
    env->ReleaseStringUTFChars(valStr, cVals[idx]);
    env->DeleteLocalRef(valStr);
  }
  delete[] cKeys;
  delete[] cVals;

  return status;
}

// JNI bridge for MXSymbolSetAttr.
// Converts `jkey` and `jvalue` to UTF-8 C strings for the duration of the
// call, passes them to the C API and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSetAttr
  (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jstring jvalue) {
  const char *keyUtf = env->GetStringUTFChars(jkey, nullptr);
  const char *valueUtf = env->GetStringUTFChars(jvalue, nullptr);

  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const int status = MXSymbolSetAttr(symbol, keyUtf, valueUtf);

  env->ReleaseStringUTFChars(jkey, keyUtf);
  env->ReleaseStringUTFChars(jvalue, valueUtf);
  return status;
}

// JNI bridge for MXSymbolListAttrShallow.
// Stores the size reported by the C API into the Base$RefInt `joutSize` and
// appends 2 * size strings from the returned array to the scala ArrayBuffer
// `jout`. Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAttrShallow
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject joutSize, jobject jout) {
  mx_uint numPairs;
  const char** attrs;
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const int status = MXSymbolListAttrShallow(symbol, &numPairs, &attrs);

  jclass intRefCls = env->FindClass("org/apache/mxnet/Base$RefInt");
  jfieldID intValueField = env->GetFieldID(intRefCls, "value", "I");
  env->SetIntField(joutSize, intValueField, static_cast<jint>(numPairs));

  jclass bufferCls = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID appendId = env->GetMethodID(bufferCls,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");

  // The returned array holds twice as many strings as the reported size.
  for (size_t idx = 0; idx < numPairs * 2; ++idx) {
    jstring entry = env->NewStringUTF(attrs[idx]);
    env->CallObjectMethod(jout, appendId, entry);
    env->DeleteLocalRef(entry);
  }

  return status;
}

// JNI bridge for MXSymbolListAttr.
// Stores the size reported by the C API into the Base$RefInt `joutSize` and
// appends 2 * size strings from the returned array to the scala ArrayBuffer
// `jout`. Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAttr
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject joutSize, jobject jout) {
  mx_uint numPairs;
  const char** attrs;
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const int status = MXSymbolListAttr(symbol, &numPairs, &attrs);

  jclass intRefCls = env->FindClass("org/apache/mxnet/Base$RefInt");
  jfieldID intValueField = env->GetFieldID(intRefCls, "value", "I");
  env->SetIntField(joutSize, intValueField, static_cast<jint>(numPairs));

  jclass bufferCls = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID appendId = env->GetMethodID(bufferCls,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");

  // The returned array holds twice as many strings as the reported size.
  for (size_t idx = 0; idx < numPairs * 2; ++idx) {
    jstring entry = env->NewStringUTF(attrs[idx]);
    env->CallObjectMethod(jout, appendId, entry);
    env->DeleteLocalRef(entry);
  }

  return status;
}

// JNI bridge for MXSymbolCompose.
// The argument count is the length of `jargs`. When `jkeys` is non-null, one
// UTF-8 key per argument is extracted from it; otherwise a null key table is
// passed to the C API. Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCompose
  (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jname,
    jobjectArray jkeys, jlongArray jargs) {
  const int numArgs = env->GetArrayLength(jargs);
  const bool hasKeys = (jkeys != NULL);

  const char **cKeys = NULL;
  if (hasKeys) {
    cKeys = new const char*[numArgs];
    for (int idx = 0; idx < numArgs; ++idx) {
      jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, idx));
      cKeys[idx] = env->GetStringUTFChars(keyStr, nullptr);
      env->DeleteLocalRef(keyStr);
    }
  }

  jlong *nativeArgs = env->GetLongArrayElements(jargs, nullptr);
  const char *nameUtf = env->GetStringUTFChars(jname, nullptr);
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  SymbolHandle *argHandles = reinterpret_cast<SymbolHandle *>(nativeArgs);
  const int status =
      MXSymbolCompose(symbol, nameUtf, static_cast<mx_uint>(numArgs), cKeys, argHandles);
  env->ReleaseStringUTFChars(jname, nameUtf);
  env->ReleaseLongArrayElements(jargs, nativeArgs, 0);

  // Release the key strings and their pointer table, if any were created.
  if (hasKeys) {
    for (int idx = 0; idx < numArgs; ++idx) {
      jstring keyStr = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, idx));
      env->ReleaseStringUTFChars(keyStr, cKeys[idx]);
      env->DeleteLocalRef(keyStr);
    }
    delete[] cKeys;
  }
  return status;
}

// JNI bridge for MXSymbolCreateVariable.
// Passes `jname` as a UTF-8 C string to the C API, stores the resulting
// symbol handle into `handle` via SetLongField and returns the status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateVariable
  (JNIEnv *env, jobject obj, jstring jname, jobject handle) {
  SymbolHandle variable;
  const char *nameUtf = env->GetStringUTFChars(jname, nullptr);
  const int status = MXSymbolCreateVariable(nameUtf, &variable);
  env->ReleaseStringUTFChars(jname, nameUtf);

  const jlong variableAddr = reinterpret_cast<jlong>(variable);
  SetLongField(env, handle, variableAddr);
  return status;
}

// JNI bridge for MXSymbolGetAttr.
// Looks up `jkey` through the C API, then writes the returned string into
// `retRef` (SetStringField) and the success flag into `successRef`
// (SetIntField). Returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAttr
  (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jobject retRef, jobject successRef) {
  const char *attrValue;
  int found;
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);

  const char *keyUtf = env->GetStringUTFChars(jkey, nullptr);
  const int status = MXSymbolGetAttr(symbol, keyUtf, &attrValue, &found);
  env->ReleaseStringUTFChars(jkey, keyUtf);

  SetStringField(env, retRef, attrValue);
  SetIntField(env, successRef, found);
  return status;
}

// JNI bridge for MXSymbolListArguments.
// Appends every string returned by the C API to the scala ArrayBuffer
// `arguments` and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListArguments
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject arguments) {
  mx_uint numNames;
  const char **names;
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const int status = MXSymbolListArguments(symbol, &numNames, &names);

  jclass bufferCls = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID appendId = env->GetMethodID(bufferCls,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
  for (size_t idx = 0; idx < numNames; ++idx) {
    jstring entry = env->NewStringUTF(names[idx]);
    env->CallObjectMethod(arguments, appendId, entry);
    env->DeleteLocalRef(entry);
  }

  return status;
}

// JNI bridge for MXSymbolListOutputs.
// Appends every string returned by the C API to the scala ArrayBuffer
// `outputs` and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListOutputs
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
  mx_uint numNames;
  const char **names;
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const int status = MXSymbolListOutputs(symbol, &numNames, &names);

  jclass bufferCls = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID appendId = env->GetMethodID(bufferCls,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
  for (size_t idx = 0; idx < numNames; ++idx) {
    jstring entry = env->NewStringUTF(names[idx]);
    env->CallObjectMethod(outputs, appendId, entry);
    env->DeleteLocalRef(entry);
  }

  return status;
}

// JNI bridge for MXSymbolListAuxiliaryStates.
// Appends every string returned by the C API to the scala ArrayBuffer
// `outputs` and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAuxiliaryStates
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
  mx_uint numNames;
  const char **names;
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const int status = MXSymbolListAuxiliaryStates(symbol, &numNames, &names);

  jclass bufferCls = env->FindClass("scala/collection/mutable/ArrayBuffer");
  jmethodID appendId = env->GetMethodID(bufferCls,
    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
  for (size_t idx = 0; idx < numNames; ++idx) {
    jstring entry = env->NewStringUTF(names[idx]);
    env->CallObjectMethod(outputs, appendId, entry);
    env->DeleteLocalRef(entry);
  }

  return status;
}

// JNI bridge for MXSymbolCopy.
// Stores the handle produced by the C API into `clonedSymbolRef` via
// SetLongField and returns the C API's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCopy
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) {
  SymbolHandle source = reinterpret_cast<SymbolHandle>(symbolPtr);
  SymbolHandle copy;
  const int status = MXSymbolCopy(source, &copy);
  SetLongField(env, clonedSymbolRef, reinterpret_cast<jlong>(copy));
  return status;
}

// Passes the handles stored in `jsymbols` to MXSymbolCreateGroup and hands the resulting
// group handle to SetLongField on `out`. Returns the MXSymbolCreateGroup status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateGroup
  (JNIEnv *env, jobject obj, jlongArray jsymbols, jobject out) {
  const int count = env->GetArrayLength(jsymbols);
  SymbolHandle group;

  // The jlong elements are reinterpreted in place as an array of SymbolHandle.
  jlong *rawHandles = env->GetLongArrayElements(jsymbols, NULL);
  SymbolHandle *members = reinterpret_cast<SymbolHandle *>(rawHandles);
  const int status = MXSymbolCreateGroup(count, members, &group);
  env->ReleaseLongArrayElements(jsymbols, rawHandles, 0);

  SetLongField(env, out, reinterpret_cast<jlong>(group));
  return status;
}

// Obtains the string produced by MXSymbolPrint for the symbol behind `symbolPtr` and passes
// it to SetStringField on `out`. Returns the MXSymbolPrint status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolPrint
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject out) {
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const char *printed;
  const int status = MXSymbolPrint(symbol, &printed);
  SetStringField(env, out, printed);
  return status;
}

// Asks MXSymbolGetOutput for output number `index` of the symbol behind `symbolPtr` and
// passes the resulting handle to SetLongField on `jout`. Returns the MXSymbolGetOutput status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetOutput
  (JNIEnv *env, jobject obj, jlong symbolPtr, jint index, jobject jout) {
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const mx_uint position = static_cast<mx_uint>(index);
  SymbolHandle selected;
  const int status = MXSymbolGetOutput(symbol, position, &selected);
  SetLongField(env, jout, reinterpret_cast<jlong>(selected));
  return status;
}

// Calls MXSymbolGetInternals on the symbol behind `symbolPtr` and passes the resulting
// handle to SetLongField on `jout`. Returns the MXSymbolGetInternals status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetInternals
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject jout) {
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  SymbolHandle internals;
  const int status = MXSymbolGetInternals(symbol, &internals);
  SetLongField(env, jout, reinterpret_cast<jlong>(internals));
  return status;
}

// Runs MXSymbolInferType for the symbol behind `symbolPtr`.
//
// The number of arguments is the length of `jvals`; `jkeys` may be null, in which case a
// NULL key array is passed to the native call. On success (status 0) the argument, output
// and auxiliary type codes are boxed as java.lang.Integer and appended to the ListBuffers
// `jargTypeData`, `joutTypeData` and `jauxTypeData`, and the completeness flag is passed to
// SetIntField on `jcomplete`. On failure none of the output objects is touched.
// Returns the MXSymbolInferType status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferType
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray jkeys, jintArray jvals,
    jobject jargTypeData, jobject joutTypeData, jobject jauxTypeData, jobject jcomplete) {
  int numArgs = env->GetArrayLength(jvals);
  const char **keys = NULL;
  if (jkeys != NULL) {
    keys = new const char *[numArgs];
    for (int i = 0; i < numArgs; i++) {
      jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
      const char *key = env->GetStringUTFChars(jkey, 0);
      keys[i] = key;
      env->DeleteLocalRef(jkey);
    }
  }

  // Everything the native call reports through out-parameters starts in a defined state.
  mx_uint inTypeSize = 0;
  const int *inTypeData = NULL;
  mx_uint outTypeSize = 0;
  const int *outTypeData = NULL;
  mx_uint auxTypeSize = 0;
  const int *auxTypeData = NULL;
  int complete = 0;

  jint *vals = env->GetIntArrayElements(jvals, NULL);
  int ret = MXSymbolInferType(reinterpret_cast<SymbolHandle>(symbolPtr),
                              static_cast<mx_uint>(numArgs), keys,
                              static_cast<const int *>(vals),
                              &inTypeSize, &inTypeData,
                              &outTypeSize, &outTypeData,
                              &auxTypeSize, &auxTypeData,
                              &complete);
  env->ReleaseIntArrayElements(jvals, vals, 0);

  // Only a successful call yields results worth copying back, mirroring the
  // `ret == 0` guard used by SymbolInferShapeHelper.
  if (ret == 0) {
    jclass integerClass = env->FindClass("java/lang/Integer");
    jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");

    jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
    jmethodID listAppend = env->GetMethodID(listClass,
      "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

    // Boxes `size` type codes and appends them to `list`, releasing both the boxed value
    // and the buffer reference returned by `+=` so local references do not pile up.
    auto appendTypes = [&](jobject list, mx_uint size, const int *types) {
      for (mx_uint i = 0; i < size; ++i) {
        jobject data = env->NewObject(integerClass, newInteger, types[i]);
        jobject appended = env->CallObjectMethod(list, listAppend, data);
        env->DeleteLocalRef(data);
        if (appended != NULL) {
          env->DeleteLocalRef(appended);
        }
      }
    };
    appendTypes(jargTypeData, inTypeSize, inTypeData);
    appendTypes(joutTypeData, outTypeSize, outTypeData);
    appendTypes(jauxTypeData, auxTypeSize, auxTypeData);

    SetIntField(env, jcomplete, complete);
  }

  // release allocated memory
  if (jkeys != NULL) {
    for (int i = 0; i < numArgs; i++) {
      jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
      env->ReleaseStringUTFChars(jkey, keys[i]);
      env->DeleteLocalRef(jkey);
    }
    delete[] keys;
  }

  return ret;
}

// Obtains the string produced by MXSymbolSaveToJSON for the symbol behind `symbolPtr` and
// passes it to SetStringField on `jout`. Returns the MXSymbolSaveToJSON status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSaveToJSON
  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject jout) {
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const char *json;
  const int status = MXSymbolSaveToJSON(symbol, &json);
  SetStringField(env, jout, json);
  return status;
}

// Feeds the UTF characters of `json` to MXSymbolCreateFromJSON and passes the resulting
// handle to SetLongField on `jhandleRef`. Returns the MXSymbolCreateFromJSON status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromJSON
  (JNIEnv *env, jobject obj, jstring json, jobject jhandleRef) {
  SymbolHandle created;
  const char *jsonChars = env->GetStringUTFChars(json, 0);
  const int status = MXSymbolCreateFromJSON(jsonChars, &created);
  SetLongField(env, jhandleRef, reinterpret_cast<jlong>(created));
  // The characters are handed back to the JVM once the handle has been stored.
  env->ReleaseStringUTFChars(json, jsonChars);
  return status;
}

// Calls MXSymbolSaveToFile for the symbol behind `symbolPtr` with the UTF characters of
// `jfname` as the file name. Returns the MXSymbolSaveToFile status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSaveToFile
  (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jfname) {
  SymbolHandle symbol = reinterpret_cast<SymbolHandle>(symbolPtr);
  const char *path = env->GetStringUTFChars(jfname, 0);
  const int status = MXSymbolSaveToFile(symbol, path);
  env->ReleaseStringUTFChars(jfname, path);
  return status;
}

// Calls MXSymbolCreateFromFile with the UTF characters of `jfname` as the file name and
// passes the resulting handle to SetLongField on `jhandleRef`. Returns the native status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile
  (JNIEnv *env, jobject obj, jstring jfname, jobject jhandleRef) {
  SymbolHandle loaded;
  const char *path = env->GetStringUTFChars(jfname, 0);
  const int status = MXSymbolCreateFromFile(path, &loaded);
  SetLongField(env, jhandleRef, reinterpret_cast<jlong>(loaded));
  // The characters are handed back to the JVM once the handle has been stored.
  env->ReleaseStringUTFChars(jfname, path);
  return status;
}

// Appends `shapeSize` shapes to `joutData` by invoking `listAppend` once per shape.
// A shape whose entry in `shapeNdim` is negative is appended as a null reference; any other
// shape is appended as a Java int[] filled from the matching row of `shapeData`.
// Returns 0 when every shape was appended, or -1 as soon as NewIntArray returns NULL.
int FillSymbolInferShape
  (JNIEnv *env, jmethodID listAppend, jobject joutData,
    int shapeSize, const int *shapeNdim, const int **shapeData) {
  int idx = 0;
  while (idx < shapeSize) {
    const int ndim = shapeNdim[idx];
    jintArray dims = NULL;
    if (ndim >= 0) {
      dims = env->NewIntArray(ndim);
      if (dims == NULL) {
        // TODO(Yizhi): out of memory error thrown, return a specific error code ?
        return -1;
      }
      const jint *row = reinterpret_cast<const jint *>(shapeData[idx]);
      env->SetIntArrayRegion(dims, 0, ndim, row);
    }
    env->CallObjectMethod(joutData, listAppend, dims);
    env->DeleteLocalRef(dims);
    ++idx;
  }
  return 0;
}

// Shared implementation of mxSymbolInferShape and mxSymbolInferShapePartial.
//
// Calls MXSymbolInferShapeEx (or MXSymbolInferShapePartialEx when `partial` is true) for the
// symbol behind `symbolPtr`, using `jnumArgs` arguments described by `jkeys` (may be null),
// `jargIndPtr` and `jargShapeData`. When the native call succeeds, the inferred input, output
// and auxiliary shapes are appended to the ListBuffers `jinShapeData`, `joutShapeData` and
// `jauxShapeData` through FillSymbolInferShape, and the completeness flag is passed to
// SetIntField on `jcomplete`.
//
// Returns the native status code, or -1 when a key string could not be read or when
// FillSymbolInferShape reports a failure.
int SymbolInferShapeHelper(JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs,
                            jobjectArray jkeys, jintArray jargIndPtr, jintArray jargShapeData,
                            jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData,
                            jobject jcomplete, bool partial) {
  // The keys are copied into owned strings and the JVM buffers are released right away.
  // This way every exit path frees them automatically, and no JNI lookups are needed for
  // cleanup while an exception raised by FillSymbolInferShape may still be pending.
  std::vector<std::string> keyStorage;
  std::vector<const char *> keyPtrs;
  // Keeps `keys` non-null for an empty key list, as the array allocation used to.
  static const char *kNoKeys[1] = { NULL };
  const char **keys = NULL;
  if (jkeys != NULL) {
    if (jnumArgs > 0) {
      keyStorage.reserve(jnumArgs);
      keyPtrs.reserve(jnumArgs);
    }
    for (int i = 0; i < jnumArgs; i++) {
      jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
      const char *key = env->GetStringUTFChars(jkey, 0);
      if (key == NULL) {
        // The characters could not be obtained; nothing else is held at this point.
        return -1;
      }
      keyStorage.emplace_back(key);
      env->ReleaseStringUTFChars(jkey, key);
      env->DeleteLocalRef(jkey);
    }
    // Pointers are collected only after all strings are in place so they stay valid.
    for (size_t i = 0; i < keyStorage.size(); ++i) {
      keyPtrs.push_back(keyStorage[i].c_str());
    }
    keys = keyPtrs.empty() ? kNoKeys : keyPtrs.data();
  }

  mx_uint inShapeSize = 0;
  const int *inShapeNdim = NULL;
  const int **inShapeData = NULL;

  mx_uint outShapeSize = 0;
  const int *outShapeNdim = NULL;
  const int **outShapeData = NULL;

  mx_uint auxShapeSize = 0;
  const int *auxShapeNdim = NULL;
  const int **auxShapeData = NULL;

  int complete = 0;

  jint *argIndPtr = env->GetIntArrayElements(jargIndPtr, NULL);
  jint *argShapeData = env->GetIntArrayElements(jargShapeData, NULL);
  int ret;
  if (!partial) {
    ret = MXSymbolInferShapeEx(reinterpret_cast<SymbolHandle>(symbolPtr),
                               static_cast<mx_uint>(jnumArgs),
                               keys,
                               reinterpret_cast<mx_uint *>(argIndPtr),
                               reinterpret_cast<const int *>(argShapeData),
                               &inShapeSize,
                               &inShapeNdim,
                               &inShapeData,
                               &outShapeSize,
                               &outShapeNdim,
                               &outShapeData,
                               &auxShapeSize,
                               &auxShapeNdim,
                               &auxShapeData,
                               &complete);
  } else {
    ret = MXSymbolInferShapePartialEx(reinterpret_cast<SymbolHandle>(symbolPtr),
                                      static_cast<mx_uint>(jnumArgs),
                                      keys,
                                      reinterpret_cast<mx_uint *>(argIndPtr),
                                      reinterpret_cast<const int *>(argShapeData),
                                      &inShapeSize,
                                      &inShapeNdim,
                                      &inShapeData,
                                      &outShapeSize,
                                      &outShapeNdim,
                                      &outShapeData,
                                      &auxShapeSize,
                                      &auxShapeNdim,
                                      &auxShapeData,
                                      &complete);
  }
  env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0);
  env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0);

  if (ret == 0) {
    jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
    jmethodID listAppend = env->GetMethodID(listClass,
      "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

    // Filled in the order in, out, aux; the first failure stops the chain.
    if (FillSymbolInferShape(
          env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData) ||
        FillSymbolInferShape(
          env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData) ||
        FillSymbolInferShape(
          env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) {
      // TODO(Yizhi): out of memory error thrown, return a specific error code ?
      ret = -1;
    } else {
      SetIntField(env, jcomplete, complete);
    }
  }

  return ret;
}

// JNI entry point that forwards all of its arguments to SymbolInferShapeHelper with
// `partial` set to false and returns the helper's result.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
  (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
    jintArray jargIndPtr, jintArray jargShapeData,
    jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
  const bool usePartialInference = false;
  const int status = SymbolInferShapeHelper(
    env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData,
    jinShapeData, joutShapeData, jauxShapeData, jcomplete, usePartialInference);
  return status;
}

// JNI entry point that forwards all of its arguments to SymbolInferShapeHelper with
// `partial` set to true and returns the helper's result.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial
  (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
    jintArray jargIndPtr, jintArray jargShapeData,
    jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
  const bool usePartialInference = true;
  const int status = SymbolInferShapeHelper(
    env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData,
    jinShapeData, joutShapeData, jauxShapeData, jcomplete, usePartialInference);
  return status;
}

// Calls MXExecutorBindX for the symbol behind `symbolPtr`.
//
// The `numCtx` strings of `jctxMapKeys` and the arrays `jctxMapDevTypes` / `jctxMapDevIDs`
// are passed as the context map, `jargsHandle`, `jargsGradHandle` and `jreqsArray` as the
// `numArgs` argument descriptions, and `jauxArgsHandle` (its length is the auxiliary state
// count) as the auxiliary states. The resulting executor handle is passed to SetLongField on
// `jexecOut`. Returns the MXExecutorBindX status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindX
  (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
    jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
    jlongArray jargsHandle, jlongArray jargsGradHandle, jintArray jreqsArray,
    jlongArray jauxArgsHandle, jobject jexecOut) {
  ExecutorHandle executor;
  const int auxCount = env->GetArrayLength(jauxArgsHandle);

  // Pin the UTF characters of every context-map key.
  const char **ctxKeys = new const char *[numCtx];
  for (int k = 0; k < numCtx; ++k) {
    jstring element = reinterpret_cast<jstring>(env->GetObjectArrayElement(jctxMapKeys, k));
    ctxKeys[k] = env->GetStringUTFChars(element, 0);
    env->DeleteLocalRef(element);
  }

  // Obtain the primitive array contents.
  jlong *auxHandles = env->GetLongArrayElements(jauxArgsHandle, NULL);
  jint *gradReqs = env->GetIntArrayElements(jreqsArray, NULL);
  jlong *argHandles = env->GetLongArrayElements(jargsHandle, NULL);
  jlong *gradHandles = env->GetLongArrayElements(jargsGradHandle, NULL);
  jint *ctxDevTypes = env->GetIntArrayElements(jctxMapDevTypes, NULL);
  jint *ctxDevIds = env->GetIntArrayElements(jctxMapDevIDs, NULL);

  const int status = MXExecutorBindX(
    reinterpret_cast<SymbolHandle>(symbolPtr),
    deviceTypeId,
    deviceID,
    static_cast<mx_uint>(numCtx),
    ctxKeys,
    ctxDevTypes,
    ctxDevIds,
    static_cast<mx_uint>(numArgs),
    reinterpret_cast<NDArrayHandle *>(argHandles),
    reinterpret_cast<NDArrayHandle *>(gradHandles),
    reinterpret_cast<mx_uint *>(gradReqs),
    static_cast<mx_uint>(auxCount),
    reinterpret_cast<NDArrayHandle *>(auxHandles),
    &executor);

  // Give everything back in the reverse order of acquisition.
  env->ReleaseIntArrayElements(jctxMapDevIDs, ctxDevIds, 0);
  env->ReleaseIntArrayElements(jctxMapDevTypes, ctxDevTypes, 0);
  env->ReleaseLongArrayElements(jargsGradHandle, gradHandles, 0);
  env->ReleaseLongArrayElements(jargsHandle, argHandles, 0);
  env->ReleaseIntArrayElements(jreqsArray, gradReqs, 0);
  env->ReleaseLongArrayElements(jauxArgsHandle, auxHandles, 0);
  for (int k = 0; k < numCtx; ++k) {
    jstring element = reinterpret_cast<jstring>(env->GetObjectArrayElement(jctxMapKeys, k));
    env->ReleaseStringUTFChars(element, ctxKeys[k]);
    env->DeleteLocalRef(element);
  }
  delete[] ctxKeys;

  SetLongField(env, jexecOut, reinterpret_cast<jlong>(executor));
  return status;
}

// Calls MXExecutorBindEX for the symbol behind `symbolPtr`.
//
// Takes the same context-map, argument and auxiliary-state inputs as mxExecutorBindX plus
// `jsharedExec`: a value of 0 means a null shared executor is passed, any other value is
// used as the shared executor handle. The resulting executor handle is passed to
// SetLongField on `jexecOut`. Returns the MXExecutorBindEX status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindEX
  (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
    jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
    jlongArray jargsHandle, jlongArray jargsGradHandle, jintArray jreqsArray,
    jlongArray jauxArgsHandle, jlong jsharedExec, jobject jexecOut) {
  ExecutorHandle out;
  int auxStatesLen = env->GetArrayLength(jauxArgsHandle);
  ExecutorHandle sharedExec = nullptr;
  // Compare the full 64-bit value: narrowing it to 32 bits first would mistake a valid
  // handle whose low 32 bits happen to be zero for "no shared executor".
  if (jsharedExec != 0) sharedExec = reinterpret_cast<ExecutorHandle>(jsharedExec);

  // Pin the UTF characters of every context-map key.
  const char **mapKeys = new const char *[numCtx];
  for (int i = 0; i < numCtx; i++) {
    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jctxMapKeys, i));
    const char *key = env->GetStringUTFChars(jkey, 0);
    mapKeys[i] = key;
    env->DeleteLocalRef(jkey);
  }
  jlong *auxStates = env->GetLongArrayElements(jauxArgsHandle, NULL);
  jint *gradReqType = env->GetIntArrayElements(jreqsArray, NULL);
  jlong *inArgs = env->GetLongArrayElements(jargsHandle, NULL);
  jlong *argGradStore = env->GetLongArrayElements(jargsGradHandle, NULL);
  jint *mapDevTypes = env->GetIntArrayElements(jctxMapDevTypes, NULL);
  jint *mapDevIDs = env->GetIntArrayElements(jctxMapDevIDs, NULL);
  int ret = MXExecutorBindEX(reinterpret_cast<SymbolHandle>(symbolPtr),
                            deviceTypeId,
                            deviceID,
                            static_cast<mx_uint>(numCtx),
                            mapKeys,
                            mapDevTypes,
                            mapDevIDs,
                            static_cast<mx_uint>(numArgs),
                            reinterpret_cast<NDArrayHandle *>(inArgs),
                            reinterpret_cast<NDArrayHandle *>(argGradStore),
                            reinterpret_cast<mx_uint *>(gradReqType),
                            static_cast<mx_uint>(auxStatesLen),
                            reinterpret_cast<NDArrayHandle *>(auxStates),
                            sharedExec,
                            &out);
  // release everything acquired above
  env->ReleaseIntArrayElements(jctxMapDevIDs, mapDevIDs, 0);
  env->ReleaseIntArrayElements(jctxMapDevTypes, mapDevTypes, 0);
  env->ReleaseLongArrayElements(jargsGradHandle, argGradStore, 0);
  env->ReleaseLongArrayElements(jargsHandle, inArgs, 0);
  env->ReleaseIntArrayElements(jreqsArray, gradReqType, 0);
  env->ReleaseLongArrayElements(jauxArgsHandle, auxStates, 0);
  for (int i = 0; i < numCtx; i++) {
    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jctxMapKeys, i));
    env->ReleaseStringUTFChars(jkey, mapKeys[i]);
    env->DeleteLocalRef(jkey);
  }
  delete[] mapKeys;

  SetLongField(env, jexecOut, reinterpret_cast<jlong>(out));
  return ret;
}

// Forwards `seed` to MXRandomSeed and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRandomSeed
  (JNIEnv *env, jobject obj, jint seed) {
  const jint status = MXRandomSeed(seed);
  return status;
}

// Calls MXNotifyShutdown and returns its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNotifyShutdown
  (JNIEnv *env, jobject obj) {
  const jint status = MXNotifyShutdown();
  return status;
}

// Calls MXRecordIOWriterCreate with the UTF characters of `juri` and passes the resulting
// handle to SetLongField on `handle`. Returns the MXRecordIOWriterCreate status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterCreate
  (JNIEnv *env, jobject obj, jstring juri, jobject handle) {
  RecordIOHandle writer;
  const char *location = env->GetStringUTFChars(juri, 0);
  const int status = MXRecordIOWriterCreate(location, &writer);
  env->ReleaseStringUTFChars(juri, location);

  const jlong writerAsLong = reinterpret_cast<jlong>(writer);
  SetLongField(env, handle, writerAsLong);
  return status;
}

// Calls MXRecordIOReaderCreate with the UTF characters of `juri` and passes the resulting
// handle to SetLongField on `handle`. Returns the MXRecordIOReaderCreate status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderCreate
  (JNIEnv *env, jobject obj, jstring juri, jobject handle) {
  RecordIOHandle reader;
  const char *location = env->GetStringUTFChars(juri, 0);
  const int status = MXRecordIOReaderCreate(location, &reader);
  env->ReleaseStringUTFChars(juri, location);

  const jlong readerAsLong = reinterpret_cast<jlong>(reader);
  SetLongField(env, handle, readerAsLong);
  return status;
}

// Reinterprets `handle` as a RecordIOHandle, passes it to MXRecordIOWriterFree and returns
// that call's status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterFree
  (JNIEnv *env, jobject obj, jlong handle) {
  RecordIOHandle writer = reinterpret_cast<RecordIOHandle>(handle);
  const int status = MXRecordIOWriterFree(writer);
  return status;
}

// Reinterprets `handle` as a RecordIOHandle and calls MXRecordIOReaderFree with the address
// of that local copy. Returns the status code of the call.
// NOTE(review): mxRecordIOWriterFree passes the handle value itself, whereas this function
// passes a pointer to a stack variable - confirm against the declaration of
// MXRecordIOReaderFree that this is what the C API expects.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderFree
  (JNIEnv *env, jobject obj, jlong handle) {
  RecordIOHandle reader = reinterpret_cast<RecordIOHandle>(handle);
  const int status = MXRecordIOReaderFree(&reader);
  return status;
}

// Passes the UTF characters of `jbuf` together with `size` to MXRecordIOWriterWriteRecord
// for the writer identified by `handle`. Returns the status code of the call.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterWriteRecord
  (JNIEnv *env, jobject obj, jlong handle, jstring jbuf, jint size) {
  const char *payload = env->GetStringUTFChars(jbuf, 0);
  RecordIOHandle *writer = reinterpret_cast<RecordIOHandle *>(handle);
  const int status = MXRecordIOWriterWriteRecord(writer, payload, size);
  // The characters are only needed for the duration of the native call.
  env->ReleaseStringUTFChars(jbuf, payload);
  return status;
}

// Calls MXRecordIOReaderReadRecord for the reader identified by `handle` and passes the
// returned character pointer to SetStringField on `buf`. The size reported by the native
// call is not used. Returns the status code of the call.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderReadRecord
  (JNIEnv *env, jobject obj, jlong handle, jobject buf) {
  RecordIOHandle *reader = reinterpret_cast<RecordIOHandle *>(handle);
  char const *record;
  size_t recordSize;
  const int status = MXRecordIOReaderReadRecord(reader, &record, &recordSize);
  SetStringField(env, buf, record);
  return status;
}

// Calls MXRecordIOWriterTell for the writer identified by `handle` and passes the reported
// position to SetIntField on `jpos`. Returns the status code of the call.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterTell
  (JNIEnv *env, jobject obj, jlong handle, jobject jpos) {
  RecordIOHandle *writer = reinterpret_cast<RecordIOHandle *>(handle);
  size_t offset;
  const int status = MXRecordIOWriterTell(writer, &offset);
  SetIntField(env, jpos, offset);
  return status;
}

// Calls MXRecordIOReaderSeek with `pos` for the reader identified by `handle` and returns
// the status code of the call.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderSeek
  (JNIEnv *env, jobject obj, jlong handle, jint pos) {
  RecordIOHandle *reader = reinterpret_cast<RecordIOHandle *>(handle);
  const int status = MXRecordIOReaderSeek(reader, pos);
  return status;
}

// Calls MXRtcCreate with the name `jname`, the input and output names from `jinputNames` /
// `joutputNames` (their lengths give the input and output counts), the NDArray handles in
// `jinputs` / `joutputs` and the kernel source `jkernel`. The resulting handle is passed to
// SetLongField on `jhandle`. Returns the MXRtcCreate status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcCreate
  (JNIEnv *env, jobject obj, jstring jname, jobjectArray jinputNames,
    jobjectArray joutputNames, jlongArray jinputs, jlongArray joutputs,
    jstring jkernel, jobject jhandle) {
  RtcHandle rtc;
  char *rtcName = const_cast<char *>(env->GetStringUTFChars(jname, 0));

  // Pin the UTF characters of every input name.
  const int inCount = env->GetArrayLength(jinputNames);
  char **inNames = new char *[inCount];
  for (int k = 0; k < inCount; ++k) {
    jstring element = reinterpret_cast<jstring>(env->GetObjectArrayElement(jinputNames, k));
    inNames[k] = const_cast<char *>(env->GetStringUTFChars(element, 0));
    env->DeleteLocalRef(element);
  }

  // Pin the UTF characters of every output name.
  const int outCount = env->GetArrayLength(joutputNames);
  char **outNames = new char *[outCount];
  for (int k = 0; k < outCount; ++k) {
    jstring element = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputNames, k));
    outNames[k] = const_cast<char *>(env->GetStringUTFChars(element, 0));
    env->DeleteLocalRef(element);
  }

  jlong *inHandles = env->GetLongArrayElements(jinputs, NULL);
  jlong *outHandles = env->GetLongArrayElements(joutputs, NULL);
  char *kernelSource = const_cast<char *>(env->GetStringUTFChars(jkernel, 0));

  const int status = MXRtcCreate(
    rtcName,
    static_cast<mx_uint>(inCount),
    static_cast<mx_uint>(outCount),
    inNames,
    outNames,
    reinterpret_cast<NDArrayHandle *>(inHandles),
    reinterpret_cast<NDArrayHandle *>(outHandles),
    kernelSource,
    &rtc);

  // Hand every pinned string and array back to the JVM and free the pointer tables.
  env->ReleaseStringUTFChars(jname, rtcName);
  env->ReleaseStringUTFChars(jkernel, kernelSource);
  env->ReleaseLongArrayElements(jinputs, inHandles, 0);
  env->ReleaseLongArrayElements(joutputs, outHandles, 0);
  for (int k = 0; k < inCount; ++k) {
    jstring element = reinterpret_cast<jstring>(env->GetObjectArrayElement(jinputNames, k));
    env->ReleaseStringUTFChars(element, inNames[k]);
    env->DeleteLocalRef(element);
  }
  delete[] inNames;
  for (int k = 0; k < outCount; ++k) {
    jstring element = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputNames, k));
    env->ReleaseStringUTFChars(element, outNames[k]);
    env->DeleteLocalRef(element);
  }
  delete[] outNames;

  SetLongField(env, jhandle, reinterpret_cast<jlong>(rtc));
  return status;
}

// Calls MXRtcPush on the Rtc handle `jhandle` with the NDArray handles stored in `jinputs`
// and `joutputs` (the array lengths give the counts) and the six grid / block dimensions.
// Returns the MXRtcPush status.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcPush
  (JNIEnv *env, jobject obj, jlong jhandle, jlongArray jinputs,
    jlongArray joutputs, jint gridDimX, jint gridDimY, jint gridDimZ,
    jint blockDimX, jint blockDimY, jint blockDimZ) {
  RtcHandle rtc = reinterpret_cast<RtcHandle>(jhandle);

  jlong *inHandles = env->GetLongArrayElements(jinputs, NULL);
  jlong *outHandles = env->GetLongArrayElements(joutputs, NULL);
  const int inCount = env->GetArrayLength(jinputs);
  const int outCount = env->GetArrayLength(joutputs);

  const int status = MXRtcPush(
    rtc,
    static_cast<mx_uint>(inCount),
    static_cast<mx_uint>(outCount),
    reinterpret_cast<NDArrayHandle *>(inHandles),
    reinterpret_cast<NDArrayHandle *>(outHandles),
    static_cast<mx_uint>(gridDimX),
    static_cast<mx_uint>(gridDimY),
    static_cast<mx_uint>(gridDimZ),
    static_cast<mx_uint>(blockDimX),
    static_cast<mx_uint>(blockDimY),
    static_cast<mx_uint>(blockDimZ));

  // Hand the array contents back to the JVM.
  env->ReleaseLongArrayElements(jinputs, inHandles, 0);
  env->ReleaseLongArrayElements(joutputs, outHandles, 0);

  return status;
}

// Reinterprets `jhandle` as an RtcHandle, passes it to MXRtcFree and returns that call's
// status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcFree
  (JNIEnv *env, jobject obj, jlong jhandle) {
  RtcHandle rtc = reinterpret_cast<RtcHandle>(jhandle);
  const int status = MXRtcFree(rtc);
  return status;
}

// store the user defined CustomOpProp object reference with its name
std::unordered_map<std::string, jobject> globalOpPropMap;
// store the user defined CustomOp object reference with its name
std::unordered_map<std::string, jobject> globalOpMap;
// used for thread safety when inserting elements into
// or erasing elements from the std::unordered_map
std::mutex mutex_opprop;
std::mutex mutex_op;

// Registers a custom operator when called
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister
  (JNIEnv *env, jobject obj, jstring jregName, jobject jopProp) {
  const char *regName = env->GetStringUTFChars(jregName, 0);
  std::string key(regName);

  std::unique_lock<std::mutex> lock(mutex_opprop);
  globalOpPropMap.insert({ key, env->NewGlobalRef(jopProp) });
  lock.unlock();

  // lambda function to initialize the operator and create all callbacks
  auto creatorLambda = [](const char *opType, const int numKwargs,
    const char  **keys, const char **values, MXCallbackList *ret) {
    int success = true;

    std::string opPropKey(opType);
    if (globalOpPropMap.find(opPropKey) == globalOpPropMap.end()) {
      LOG(WARNING) << "CustomOpProp: " << opPropKey << " not found";
      success = false;
    } else {
      JNIEnv *env;
      _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
      jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(opPropKey));
      jmethodID midInit = env->GetMethodID(opPropClass,
        "init", "([Ljava/lang/String;[Ljava/lang/String;)V");
      if (NULL == midInit) {
        LOG(WARNING) << "could not find CustomOpProp method init.";
        success = false;
      } else {
        // call init and set CustomOpProp.kwargs
        jclass strCls = env->FindClass("Ljava/lang/String;");
        jobjectArray keysArr = env->NewObjectArray(numKwargs, strCls, NULL);
        jobjectArray valuesArr = env->NewObjectArray(numKwargs, strCls, NULL);
        for (int i = 0; i < numKwargs; ++i) {
          jstring keyStr = env->NewStringUTF(keys[i]);
          jstring valueStr = env->NewStringUTF(values[i]);
          env->SetObjectArrayElement(keysArr, i, keyStr);
          env->SetObjectArrayElement(valuesArr, i, valueStr);
          env->DeleteLocalRef(keyStr);
          env->DeleteLocalRef(valueStr);
        }
        env->CallVoidMethod(globalOpPropMap.at(opPropKey), midInit, keysArr, valuesArr);
        env->DeleteLocalRef(keysArr);
        env->DeleteLocalRef(valuesArr);
      }
      _jvm->DetachCurrentThread();
    }

    // list_arguments callback
    auto opPropListArgument = [](char ***args, void *state) {
      int success = true;
      std::string key(reinterpret_cast<char *>(state));
      if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
        LOG(WARNING) << "CustomOpProp: " << key << " not found";
        success = false;
      } else {
        JNIEnv *env;
        _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
        jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
        jmethodID midListArguments = env->GetMethodID(
          opPropClass, "listArguments", "()[Ljava/lang/String;");
        if (NULL == midListArguments) {
          LOG(WARNING) << "could not find opProp method listArguments.";
          success = false;
        } else {
          jobjectArray jargs =(jobjectArray)(env->CallObjectMethod(
            globalOpPropMap.at(key), midListArguments));
          int len = env->GetArrayLength(jargs);
          *args = new char *[len+1];
          for (int i = 0; i < len; ++i) {
            jstring jarg = reinterpret_cast<jstring>(env->GetObjectArrayElement(jargs, i));
            const char *arg = env->GetStringUTFChars(jarg, 0);
            (*args)[i] = const_cast<char *>(arg);
            env->DeleteLocalRef(jarg);
          }
          (*args)[len] = NULL;
        }
        _jvm->DetachCurrentThread();
      }
      return success;
    };

    // list_outputs callback
    auto opPropListOutputs = [](char ***outputs, void *state) {
      int success = true;
      std::string key(reinterpret_cast<char *>(state));
      if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
        LOG(WARNING) << "CustomOpProp: " << key << " not found";
        success = false;
      } else {
        JNIEnv *env;
        _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
        jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
        jmethodID midListOutputs = env->GetMethodID(
          opPropClass, "listOutputs", "()[Ljava/lang/String;");
        if (NULL == midListOutputs) {
          LOG(WARNING) << "could not find opProp method listOutputs.";
          success = false;
        } else {
          jobjectArray joutputs = (jobjectArray)(env->CallObjectMethod(
            globalOpPropMap.at(key), midListOutputs));
          int len = env->GetArrayLength(joutputs);
          *outputs = new char *[len + 1];
          for (int i = 0; i < len; ++i) {
            jstring joutput = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputs, i));
            const char *output = env->GetStringUTFChars(joutput, 0);
            (*outputs)[i] = const_cast<char *>(output);
            env->DeleteLocalRef(joutput);
          }
          (*outputs)[len] = NULL;
        }
        _jvm->DetachCurrentThread();
      }
      return success;
    };

    // list_auxiliary_states callback
    auto opPropListAuxStates = [](char ***auxs, void *state) {
      int success = true;
      std::string key(reinterpret_cast<char *>(state));
      if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
        LOG(WARNING) << "CustomOpProp: " << key << " not found";
        success = false;
      } else {
        JNIEnv *env;
        _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
        jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
        jmethodID midListAuxStates = env->GetMethodID(
          opPropClass, "listAuxiliaryStates", "()[Ljava/lang/String;");
        if (NULL == midListAuxStates) {
          LOG(WARNING) << "could not find opProp method listAuxiliaryStates.";
          success = false;
        } else {
          auto obj = env->CallObjectMethod(globalOpPropMap.at(key), midListAuxStates);
          if (obj != NULL) {
            jobjectArray jauxs = (jobjectArray)obj;
            int len = env->GetArrayLength(jauxs);
            *auxs = new char *[len+1];
            for (int i = 0; i < len; ++i) {
              jstring jaux = reinterpret_cast<jstring>(env->GetObjectArrayElement(jauxs, i));
              const char *aux = env->GetStringUTFChars(jaux, 0);
              (*auxs)[i] = const_cast<char *>(aux);
              env->DeleteLocalRef(jaux);
            }
            (*auxs)[len] = NULL;
          } else {
            (*auxs) = new char *[1];
            (*auxs)[0] = NULL;
          }
        }
        _jvm->DetachCurrentThread();
      }
      return success;
    };

    // declare_backward_dependency callback
    auto opPropDeclareBkDep = [](const int *outGrad, const int *inData,
      const int *outData, int *numDeps, int **rdeps, void *state) {
      int success = true;
      std::string key(reinterpret_cast<char *>(state));
      if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
        LOG(WARNING) << "CustomOpProp: " << key << " not found";
        success = false;
      } else {
        JNIEnv *env;
        _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
        jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
        jmethodID midDeclareBkDep = env->GetMethodID(
          opPropClass, "declareBackwardDependency", "([I[I[I)[I");
        if (NULL == midDeclareBkDep) {
          LOG(WARNING) << "could not find opProp method declareBackwardDependency.";
          success = false;
        } else {
          jmethodID midListOutputs = env->GetMethodID(
            opPropClass, "listOutputs", "()[Ljava/lang/String;");
          jobjectArray joutputs = (jobjectArray)(env->CallObjectMethod(
            globalOpPropMap.at(key), midListOutputs));
          int outLen = env->GetArrayLength(joutputs);
          jmethodID midListArguments = env->GetMethodID(
            opPropClass, "listArguments", "()[Ljava/lang/String;");
          jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
            globalOpPropMap.at(key), midListArguments));
          int intLen = env->GetArrayLength(jargs);

          jintArray outGradArr = env->NewIntArray(outLen);
          env->SetIntArrayRegion(outGradArr, (jsize)0, (jsize)outLen, outGrad);
          jintArray inDataArr = env->NewIntArray(intLen);
          env->SetIntArrayRegion(inDataArr, (jsize)0, (jsize)intLen, inData);
          jintArray outDataArr = env->NewIntArray(outLen);
          env->SetIntArrayRegion(outDataArr, (jsize)0, (jsize)outLen, outData);

          auto obj = env->CallObjectMethod(globalOpPropMap.at(key), midDeclareBkDep,
                                                   outGradArr,
                                                   inDataArr,
                                                   outDataArr);
          jintArray jrdeps = (jintArray)obj;
          jint *rdepsArr = env->GetIntArrayElements(jrdeps, NULL);

          *numDeps = env->GetArrayLength(jrdeps);
          *rdeps = new int[(* numDeps)];
          for (int i = 0 ; i < (*numDeps); ++i) {
            (*rdeps)[i] = rdepsArr[i];
          }
          env->DeleteLocalRef(outGradArr);
          env->DeleteLocalRef(inDataArr);
          env->DeleteLocalRef(outDataArr);
          env->ReleaseIntArrayElements(jrdeps, rdepsArr, 0);
        }
        _jvm->DetachCurrentThread();
      }
      return success;
    };

    // infer_shape callback
    auto opPropInferShape = [](int numInput, int *ndims,
      unsigned **shapes, void *state) {
      int success = true;
      std::string key(reinterpret_cast<char *>(state));
      if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
        LOG(WARNING) << "CustomOpProp: " << key << " not found";
        success = false;
      } else {
        JNIEnv *env;
        _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
        jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
        jmethodID midInferShape = env->GetMethodID(opPropClass, "inferShapeEntry", "(I[[I)[[I");
        if (NULL == midInferShape) {
          LOG(WARNING) << "could not find opProp method inferShapeEntry.";
          success = false;
        } else {
          jmethodID midListArguments = env->GetMethodID(
            opPropClass, "listArguments", "()[Ljava/lang/String;");
          jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
            globalOpPropMap.at(key), midListArguments));
          int intLen = env->GetArrayLength(jargs);
          jintArray *ts = new jintArray[intLen];
          auto tmp = env->NewIntArray(1);
          jclass arrayClass = env->GetObjectClass(tmp);
          env->DeleteLocalRef(tmp);
          jobjectArray tensorShapes = env->NewObjectArray(intLen, arrayClass, NULL);
          for (int i = 0; i < intLen; ++i) {
            ts[i] = env->NewIntArray(ndims[i]);
            env->SetIntArrayRegion(
              ts[i], (jsize)0, (jsize)ndims[i], reinterpret_cast<int *>(shapes[i]));
            env->SetObjectArrayElement(tensorShapes, i, (jobject)(ts[i]));
          }
          jobjectArray ret = (jobjectArray)(env->CallObjectMethod(
            globalOpPropMap.at(key), midInferShape,
            numInput,
            tensorShapes));
          for (int i = 0; i < numInput; ++i) {
            jintArray jarr = reinterpret_cast<jintArray>(env->GetObjectArrayElement(ret, i));
            int len = env->GetArrayLength(jarr);
            jint *arr = env->GetIntArrayElements(jarr, NULL);
            ndims[i] = len;
            shapes[i] = new unsigned[len];
            for (int j = 0; j < len; ++j) shapes[i][j] = (unsigned)(arr[j]);
            env->DeleteLocalRef(jarr);
          }
          for (int i = 0; i < intLen; ++i) {
            env->DeleteLocalRef(ts[i]);
          }
          delete[] ts;
        }
        _jvm->DetachCurrentThread();
      }
      return success;
    };

    // infer_type callback
    auto opPropInferType = [](int numInput, int* types, void* state) {
      int success = true;
      std::string key(reinterpret_cast<char *>(state));
      if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
        LOG(WARNING) << "CustomOpProp: " << key << " not found";
        success = false;
      } else {
        JNIEnv *env;
        _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
        jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
        jmethodID midInferType = env->GetMethodID(opPropClass, "inferTypeEntry", "(I[I)[I");
        if (NULL == midInferType) {
          LOG(WARNING) << "could not find opProp method inferTypeEntry.";
          success = false;
        } else {
          jmethodID midListArguments = env->GetMethodID(
            opPropClass, "listArguments", "()[Ljava/lang/String;");
          jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
            globalOpPropMap.at(key), midListArguments));

          int intLen = env->GetArrayLength(jargs);
          jintArray ts = env->NewIntArray(intLen);
          int *tmp = new int[intLen];
          for (int i = 0; i < intLen; ++i) tmp[i] = types[i];
          env->SetIntArrayRegion(ts, (jsize)0, (jsize)intLen, tmp);

          jintArray ret = (jintArray)(env->CallObjectMethod(
            globalOpPropMap.at(key), midInferType,
            numInput,
            ts));
          jint *arr = env->GetIntArrayElements(ret, NULL);
          for (int i = 0; i < numInput; ++i) {
            types[i] = static_cast<int>(arr[i]);
          }

          delete[] tmp;
          env->ReleaseIntArrayElements(ret, arr, 0);
          env->DeleteLocalRef(ret);
          env->DeleteLocalRef(ts);
        }
        _jvm->DetachCurrentThread();
      }
      return success;
    };

    // create_operator callback
    auto opPropCreateOp = [](const char *ctx, int numInputs,
      unsigned **shapes, int *ndims, int *dtypes, MXCallbackList *ret, void *state) {
      int success = true;
      std::string key(reinterpret_cast<char *>(state));
      if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
        LOG(WARNING) << "CustomOpProp: " << key << " not found";
        success = false;
      } else {
        JNIEnv *env;
        _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
        jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
        jmethodID midCreateOp = env->GetMethodID(
          opPropClass, "createOperator", "(Ljava/lang/String;[[I[I)Lorg/apache/mxnet/CustomOp;");
        if (NULL == midCreateOp) {
          LOG(WARNING) << "could not find opProp method createOperator.";
          success = false;
        } else {
          jstring jctx = env->NewStringUTF(ctx);
          jintArray *ts = new jintArray[numInputs];
          auto tmp = env->NewIntArray(1);
          jclass arrayClass = env->GetObjectClass(tmp);
          env->DeleteLocalRef(tmp);
          jobjectArray inputShapes = env->NewObjectArray(numInputs, arrayClass, NULL);
          for (int i = 0; i < numInputs; ++i) {
            ts[i] = env->NewIntArray(ndims[i]);
            env->SetIntArrayRegion(
              ts[i], (jsize)0, (jsize)ndims[i], reinterpret_cast<int *>(shapes[i]));
            env->SetObjectArrayElement(inputShapes, i, (jobject)(ts[i]));
          }
          jintArray jdtypes = env->NewIntArray(numInputs);
          env->SetIntArrayRegion(jdtypes, (jsize)0, (jsize)numInputs, dtypes);
          // get operator
          jobject jOp = env->CallObjectMethod(globalOpPropMap.at(key), midCreateOp,
                                        jctx,
                                        inputShapes,
                                        jdtypes);
          env->DeleteLocalRef(jctx);
          for (int i = 0; i < numInputs; ++i) {
            env->DeleteLocalRef(ts[i]);
          }
          delete[] ts;

          std::unique_lock<std::mutex> lock(mutex_op);
          globalOpMap.insert({ key, env->NewGlobalRef(jOp) });
          lock.unlock();

          _jvm->DetachCurrentThread();

          // forward callback
          auto forwardEntry = [](int size, void **ptrs, int *tags,
            const int *reqs, const int isTrain, void *state) {
            std::string key(reinterpret_cast<char *>(state));
            int success = true;
            if (globalOpMap.find(key) == globalOpMap.end()) {
              LOG(WARNING) << "op: " << key << " not found";
              success = false;
            } else {
              JNIEnv *env;
              _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
              jclass opClass =  env->GetObjectClass(globalOpMap.at(key));
              jmethodID midForward = env->GetMethodID(opClass, "forwardEntry", "(I[J[I[IZ)Z");
              if (NULL == midForward) {
                LOG(WARNING) << "could not find op method forwardEntry.";
                success = false;
              } else {
                jintArray tagsArr = env->NewIntArray(size);
                env->SetIntArrayRegion(tagsArr, (jsize)0, (jsize)size, tags);
                int reqSize = 0;
                for (int i = 0; i < size; ++i) {
                  if (tags[i] == 1) reqSize++;
                }
                jintArray reqsArr = env->NewIntArray(reqSize);
                env->SetIntArrayRegion(reqsArr, (jsize)0, (jsize)reqSize, reqs);
                jlongArray ptrsArr = env->NewLongArray(size);
                env->SetLongArrayRegion(
                  ptrsArr, (jsize)0, (jsize)size, reinterpret_cast<jlong*>(ptrs));
#if MXNET_USE_CUDA
                mxnet::NDArray* tmp = reinterpret_cast<mxnet::NDArray*>(ptrs[0]);
                if (tmp->ctx().dev_type == mxnet::Context::kGPU
                  || tmp->ctx().dev_type == mxnet::Context::kCPUPinned) {
                  CUDA_CALL(cudaSetDevice(tmp->ctx().dev_id));
                }
#endif
                bool is_train =  true;
                if (isTrain == 0) is_train = false;
                success = env->CallBooleanMethod(globalOpMap.at(key), midForward,
                                                       size,
                                                       ptrsArr,
                                                       tagsArr,
                                                       reqsArr,
                                                       is_train);
                env->DeleteLocalRef(tagsArr);
                env->DeleteLocalRef(reqsArr);
                env->DeleteLocalRef(ptrsArr);
              }
              _jvm->DetachCurrentThread();
            }
            return success;
          };

          // backward callback
          auto backwardEntry = [](int size, void **ptrs, int *tags,
            const int *reqs, const int isTrain, void *state) {
            std::string key(reinterpret_cast<char *>(state));
            int success = true;
            if (globalOpMap.find(key) == globalOpMap.end()) {
              LOG(WARNING) << "op: " << key << " not found";
              success = false;
            } else {
              JNIEnv *env;
              _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
              jclass opClass = env->GetObjectClass(globalOpMap.at(key));
              jmethodID midBackward = env->GetMethodID(opClass, "backwardEntry", "(I[J[I[IZ)Z");
              if (NULL == midBackward) {
                LOG(WARNING) << "could not find op method backwardEntry.";
                success = false;
              } else {
                jintArray tagsArr = env->NewIntArray(size);
                env->SetIntArrayRegion(tagsArr, (jsize)0, (jsize)size, tags);

                int reqSize = 0;
                for (int i = 0; i < size; ++i) {
                  if (tags[i] == 2) reqSize++;
                }
                jintArray reqsArr = env->NewIntArray(reqSize);
                env->SetIntArrayRegion(reqsArr, (jsize)0, (jsize)reqSize, reqs);
                jlongArray ptrsArr = env->NewLongArray(size);
                env->SetLongArrayRegion(
                  ptrsArr, (jsize)0, (jsize)size, reinterpret_cast<jlong*>(ptrs));
                bool is_train =  true;
                if (isTrain == 0) is_train = false;
                success = env->CallBooleanMethod(globalOpMap.at(key), midBackward,
                                                       size,
                                                       ptrsArr,
                                                       tagsArr,
                                                       reqsArr,
                                                       is_train);
                env->DeleteLocalRef(tagsArr);
                env->DeleteLocalRef(reqsArr);
                env->DeleteLocalRef(ptrsArr);
              }
              _jvm->DetachCurrentThread();
            }
            return success;
          };

          // del callback
          auto delEntry = [](void *state) {
            std::string key(reinterpret_cast<char *>(state));
            int success = true;
            std::unique_lock<std::mutex> lock(mutex_op);
            if (globalOpMap.find(key) == globalOpMap.end()) {
              LOG(WARNING) << "op: " << key << " not found";
              success = false;
            } else {
              JNIEnv *env;
              _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
              env->DeleteGlobalRef(globalOpMap.at(key));
              _jvm->DetachCurrentThread();
              for (auto it = globalOpMap.begin(); it != globalOpMap.end(); ) {
                if (it->first == key) {
                  it = globalOpMap.erase(it);
                } else {
                  ++it;
                }
              }
            }
            lock.unlock();
            return success;
          };

          // TODO(eric): Memory leak here. Refactor later and delete in delEntry
          ret->num_callbacks = 3;
          ret->callbacks = new MXGenericCallback[ret->num_callbacks];
          ret->callbacks[kCustomOpDelete] =
            reinterpret_cast<int(*)(void)>(static_cast<int(*)(void*)>(delEntry));
          ret->callbacks[kCustomOpForward] =
            reinterpret_cast<int(*)(void)>(
              static_cast<int(*)(int, void**, int*, const int*, const int, void*)>(
                forwardEntry));
          ret->callbacks[kCustomOpBackward] =
            reinterpret_cast<int(*)(void)>(
              static_cast<int(*)(int, void**, int*, const int*, const int, void*)>(
                backwardEntry));
          ret->contexts = new void*[ret->num_callbacks];
          ret->contexts[kCustomOpDelete] = state;
          ret->contexts[kCustomOpForward] = state;
          ret->contexts[kCustomOpBackward] = state;
        }
      }
      return success;
    };

    // del callback
    auto opPropDel = [](void * /*state*/) {
      /*
       * Intentionally a no-op. The engine appears to invoke this to clean up after
       * repeated calls to the creator lambda, but that lambda seems to reinitialize
       * the object created at registration time rather than allocate a new one, so
       * nothing needs releasing here (earlier attempts at cleanup ended up
       * deregistering the operator).
       */
      constexpr int kNothingToRelease = 1;
      return kNothingToRelease;
    };

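    // TODO(eric): Memory leak: the callback and context arrays allocated below are not
    // freed here. All eight callbacks, including infer_type, are registered.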
    ret->num_callbacks = 8;
    ret->callbacks = new MXGenericCallback[ret->num_callbacks];
    ret->callbacks[kCustomOpPropDelete] =
      reinterpret_cast<int(*)(void)>(
        static_cast<int(*)(void*)>(opPropDel));
    ret->callbacks[kCustomOpPropListArguments] =
      reinterpret_cast<int(*)(void)>(
        static_cast<int(*)(char***, void*)>(opPropListArgument));
    ret->callbacks[kCustomOpPropListOutputs] =
      reinterpret_cast<int(*)(void)>(
        static_cast<int(*)(char***, void*)>(opPropListOutputs));
    ret->callbacks[kCustomOpPropListAuxiliaryStates] =
      reinterpret_cast<int(*)(void)>(
        static_cast<int(*)(char***, void*)>(opPropListAuxStates));
    ret->callbacks[kCustomOpPropInferShape] =
      reinterpret_cast<int(*)(void)>(
        static_cast<int (*)(int, int*, unsigned**, void*)>(opPropInferShape));
    ret->callbacks[kCustomOpPropDeclareBackwardDependency] =
      reinterpret_cast<int(*)(void)>(
        static_cast<int(*)(const int*, const int*, const int*, int* num_deps, int**, void*)>(
          opPropDeclareBkDep));
    ret->callbacks[kCustomOpPropCreateOperator] =
      reinterpret_cast<int(*)(void)>(
        static_cast<int(*)(const char*, int, unsigned**, int*, int*, MXCallbackList*, void*)>(
          opPropCreateOp));
    ret->callbacks[kCustomOpPropInferType] =
      reinterpret_cast<int(*)(void)>(
        static_cast<int(*)(int, int*, void*)>(opPropInferType));

    ret->contexts = new void*[ret->num_callbacks];
    ret->contexts[kCustomOpPropDelete] =
      reinterpret_cast<void *>(const_cast<char *>(opType));
    ret->contexts[kCustomOpPropListArguments] =
      reinterpret_cast<void *>(const_cast<char *>(opType));
    ret->contexts[kCustomOpPropListOutputs] =
      reinterpret_cast<void *>(const_cast<char *>(opType));
    ret->contexts[kCustomOpPropListAuxiliaryStates] =
      reinterpret_cast<void *>(const_cast<char *>(opType));
    ret->contexts[kCustomOpPropInferShape] =
      reinterpret_cast<void *>(const_cast<char *>(opType));
    ret->contexts[kCustomOpPropDeclareBackwardDependency] =
      reinterpret_cast<void *>(const_cast<char *>(opType));
    ret->contexts[kCustomOpPropCreateOperator] =
      reinterpret_cast<void *>(const_cast<char *>(opType));
    ret->contexts[kCustomOpPropInferType] =
      reinterpret_cast<void *>(const_cast<char *>(opType));
    return success;
  };

  CustomOpPropCreator creator =
    static_cast<int(*)(const char*, const int, const char**, const char**, MXCallbackList*)>(
      creatorLambda);
  return MXCustomOpRegister(regName, creator);
}

/*!
 * \brief Scoped holder for the UTF chars of a Java string.
 *
 * The constructor obtains the characters with GetStringUTFChars and the
 * destructor hands them back with ReleaseStringUTFChars. The object is
 * non-copyable so the characters are released exactly once. A null
 * java_string yields a null str_ instead of being passed to JNI.
 */
struct JNIString {
  JNIEnv *env_;
  jstring java_string_;
  const char *str_;
  inline JNIString(JNIEnv *env, const jstring& java_string)
  : env_(env)
    , java_string_(java_string)
    , str_(NULL) {
    if (java_string_ != NULL) {
      str_ = env_->GetStringUTFChars(java_string_, 0);
    }
  }
  // Copying would make two owners release the same buffer.
  JNIString(const JNIString&) = delete;
  JNIString& operator=(const JNIString&) = delete;
  inline ~JNIString() {
    if (str_) {
      env_->ReleaseStringUTFChars(java_string_, str_);
    }
  }
  /*! \brief The held characters, or NULL if none were obtained. */
  inline const char *operator ()() const {
    return str_;
  }
};

/*!
 * \brief Exposes a Java String[] as a C array of const char *.
 *
 * Each element is wrapped in a JNIString owned by jni_strings_, so the
 * pointers in strings_ stay valid for the lifetime of this object.
 */
struct JNIStringArray {
  std::vector<std::unique_ptr<JNIString>> jni_strings_;
  std::vector<const char *> strings_;
  JNIStringArray(JNIEnv *env, const jobjectArray& stringArray) {
    const int count = env->GetArrayLength(stringArray);
    jni_strings_.reserve(count);
    strings_.reserve(count);
    for (int i = 0; i < count; ++i) {
      jstring string = static_cast<jstring>(env->GetObjectArrayElement(stringArray, i));
      jni_strings_.emplace_back(std::unique_ptr<JNIString>(new JNIString(env, string)));
      strings_.emplace_back(jni_strings_.back()->str_);
    }
  }
  /*!
   * \brief Pointer to the first C string.
   * Uses data() rather than &strings_[0], which is undefined for an empty array.
   */
  const char * const* operator ()() const { return strings_.data(); }
};

// Forwards parallel key/value String arrays to MXSetProfilerConfig and returns
// its status code.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetProfilerConfig
  (JNIEnv *env, jobject obj, jobjectArray keys, jobjectArray vals) {
  // A single count is passed for both arrays, so their lengths are checked first.
  const int numPairs = env->GetArrayLength(keys);
  CHECK_EQ(numPairs, env->GetArrayLength(vals)) << "Key and value arrays must be the same size";

  // The wrappers hold the C strings until this function returns.
  JNIStringArray keyStrings(env, keys);
  JNIStringArray valStrings(env, vals);

  return MXSetProfilerConfig(numPairs, keyStrings(), valStrings());
}

// Pass-through to MXSetProfilerState; its status code is the JNI return value.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetProfilerState
  (JNIEnv *env, jobject obj, jint jstate) {
  const int status = MXSetProfilerState(jstate);
  return status;
}

// Pass-through to MXDumpProfile; its status code is the JNI return value.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
  (JNIEnv *env, jobject obj, jint finished) {
  const int status = MXDumpProfile(finished);
  return status;
}

// Numpy
// Queries MXIsNumpyShape and stores the flag (as 0/1) into compatibleRef via
// SetIntField. Returns the status code of MXIsNumpyShape.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
  (JNIEnv *env, jobject obj, jobject compatibleRef) {
  // Initialised so a defined value is stored even if the native call fails
  // before writing its out-parameter.
  bool isNumpyShape = false;
  int ret = MXIsNumpyShape(&isNumpyShape);
  SetIntField(env, compatibleRef, static_cast<int>(isNumpyShape));
  return ret;
}

// Calls MXSetIsNumpyShape with isNpComp and stores the previous setting it
// reports into prevRef via SetIntField. Returns the status code of
// MXSetIsNumpyShape.
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape
  (JNIEnv *env, jobject obj, jint isNpComp, jobject prevRef) {
  // Initialised so a defined value is stored even if the native call fails
  // before writing its out-parameter.
  int prev = 0;
  int ret = MXSetIsNumpyShape(isNpComp, &prev);
  SetIntField(env, prevRef, prev);
  return ret;
}
